fix bug with float8 training + FSDP2 + TP (pytorch#1327)
Summary:

The combination of float8 training + FSDP2 + TP recently broke. This PR fixes it by:

1. adding a test case so we have this combination covered in CI, and
2. fixing the code so the new test passes, by ensuring we check for `Float8Tensor`
   properly when it is wrapped in `DTensor` (a condensed sketch of that check
   follows below).
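For reference, the core of the fix is the `tensor_already_casted_to_fp8` helper, which this commit moves from `torchao/float8/float8_tensor.py` into `torchao/float8/distributed_utils.py`, and which `Float8Linear` now calls instead of a bare `isinstance` check. A condensed sketch mirroring the helper as it appears in the diff below:

```
import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed._tensor import DTensor

from torchao.float8.float8_tensor import Float8Tensor


def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool:
    """Check if the tensor is already casted to fp8; also works when the
    Float8Tensor is wrapped in a DTensor or an AsyncCollectiveTensor."""
    if isinstance(tensor, Float8Tensor):
        return True
    elif isinstance(tensor, DTensor):
        # unwrap and check the local shard
        return tensor_already_casted_to_fp8(tensor._local_tensor)
    elif isinstance(tensor, funcol.AsyncCollectiveTensor):
        # unwrap the pending collective result
        return tensor_already_casted_to_fp8(tensor.elem)
    return False
```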

Note 1: most of the code in `distributed_utils.py` was dead code from
before we switched to DTensor, so I deleted it in this PR.
Note 2: we already have extensive testing for FSDP2 and TP/SP in
separate files. I chose to create a new file for testing those two
features together to keep complexity and test runtime manageable.
Note 3: we really should make these distributed test cases run in CI;
right now they are run locally only.
Note 4: there are a couple of future follow-ups which would be
interesting:
- in FSDP2 with float8-all-gather, perhaps we should return
  DTensor(Float8Tensor) instead of Float8Tensor, to stay consistent with
  how FSDP2 wraps weights without float8-all-gather
- in `DTensor`, it would be nice if `isinstance(t, Float8Tensor)` returned
  True when `t` is a `DTensor` wrapping a `Float8Tensor`; food for thought
  for composability. Having this would let us simplify some of the float8
  modeling code (a rough illustrative sketch follows below).
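To make the composability point above concrete, here is a rough, hedged sketch; `skip_weight_cast_today` and `skip_weight_cast_if_isinstance_composed` are illustrative names only, and the second function is hypothetical, not an API proposal:

```
import torch

from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
from torchao.float8.float8_tensor import Float8Tensor


def skip_weight_cast_today(weight: torch.Tensor) -> bool:
    # what this PR does: the helper explicitly unwraps DTensor (and
    # AsyncCollectiveTensor) before checking for Float8Tensor
    return tensor_already_casted_to_fp8(weight)


def skip_weight_cast_if_isinstance_composed(weight: torch.Tensor) -> bool:
    # hypothetical: if isinstance() saw through DTensor to the local tensor,
    # a single check would suffice and the unwrapping helper could go away
    return isinstance(weight, Float8Tensor)
```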

Test Plan:

```
# tests added in this PR
./test/float8/test_dtensor.sh

# all tests
./test/float8/test_everything.sh

# torchtitan command fails before this PR and passes after
with-proxy CONFIG_FILE="./train_configs/debug_model.toml" ./run_llama_train.sh --float8.enable_float8_linear --float8.enable_fsdp_float8_all_gather --training.tensor_parallel_degree 2
```

Reviewers:

Subscribers:

Tasks:

Tags:
vkuzo authored and sunjiweiswift committed Nov 25, 2024
1 parent 8acd8e7 commit 05ed2e6
Showing 8 changed files with 177 additions and 144 deletions.
25 changes: 1 addition & 24 deletions test/float8/test_dtensor.py
@@ -15,8 +15,6 @@

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

@@ -49,6 +47,7 @@
)
from torchao.float8.float8_utils import e4m3_dtype, tensor_to_scale
from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
from torchao.testing.float8.dtensor_utils import ToyModel


def setup_distributed():
@@ -59,28 +58,6 @@ def setup_distributed():
return device_mesh


class FeedForward(nn.Module):
"""MLP based model"""

def __init__(self):
super(FeedForward, self).__init__()
self.w1 = nn.Linear(16, 32, bias=False)
self.w2 = nn.Linear(16, 32, bias=False)
self.out_proj = nn.Linear(32, 16, bias=False)

def forward(self, x):
return self.out_proj(F.silu(self.w1(x)) * self.w2(x))


class ToyModel(nn.Module):
def __init__(self):
super(ToyModel, self).__init__()
self.ffn = FeedForward()

def forward(self, x):
return self.ffn(x)


def _test_scaled_mm(mesh: DeviceMesh, size=16):
device = mesh.device_type
fp8_dtype = e4m3_dtype
4 changes: 4 additions & 0 deletions test/float8/test_dtensor.sh
@@ -8,4 +8,8 @@ if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False";
exit
fi

# integration tests for TP/SP
NCCL_DEBUG=WARN torchrun --nproc_per_node 2 test/float8/test_dtensor.py

# integration smoke tests for FSDP2 + TP
NCCL_DEBUG=WARN torchrun --nproc_per_node 4 test/float8/test_fsdp2_tp.py
121 changes: 121 additions & 0 deletions test/float8/test_fsdp2_tp.py
@@ -0,0 +1,121 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
"""
Test numerics of manually defined float16 TP vs float8 TP of toy models
Note: for now, this does not run in CI.
TODO(future): make this run in CI
"""

import copy
import os

import pytest
import torch

from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

if not TORCH_VERSION_AT_LEAST_2_5:
pytest.skip("Unsupported PyTorch version", allow_module_level=True)

from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module
from tqdm import tqdm

from torchao.float8 import Float8LinearConfig
from torchao.float8.float8_linear_utils import convert_to_float8_training
from torchao.float8.float8_tensor_parallel import (
Float8ColwiseParallel,
Float8RowwiseParallel,
)
from torchao.testing.float8.dtensor_utils import ToyModel


def setup_distributed():
world_size = int(os.environ.get("WORLD_SIZE", -1))

# https://pytorch.org/tutorials/recipes/distributed_device_mesh.html
device_mesh = init_device_mesh(
"cuda",
(world_size // 2, 2),
mesh_dim_names=("dp", "tp"),
)
# seed must be the same in all processes
torch.manual_seed(1)
return device_mesh


def _test_fp8_mlp_tensor_parallelism_base(
mesh: DeviceMesh, size=16, compile: bool = False
):
device = mesh.device_type

config = Float8LinearConfig(
emulate=True,
enable_fsdp_float8_all_gather=True,
)

toy_model = ToyModel().to(device)

tp_model = copy.deepcopy(toy_model)
tp_model = convert_to_float8_training(tp_model, config=config)

# apply TP
tp_model = parallelize_module(
tp_model,
mesh["tp"],
{
"ffn.w1": Float8ColwiseParallel(),
"ffn.w2": Float8ColwiseParallel(),
"ffn.out_proj": Float8RowwiseParallel(),
},
)

if compile:
tp_model = torch.compile(tp_model)

# apply FSDP
fsdp_config = {"mesh": mesh["dp"]}
tp_model = fully_shard(tp_model, **fsdp_config)

x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
x_fp32_tp_input = x_fp32.clone()

tp_out = tp_model(x_fp32_tp_input)
tp_out.sum().backward()
torch.cuda.synchronize()

# TODO(future PR): test numerics, and add more cases


def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
_test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False)


def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
_test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True)


if __name__ == "__main__":
# float8 only works on CUDA H100 so we only test cuda and we follow
# other test files to not use TestCase but instead just add the test
# cases in the main func.
device_mesh = setup_distributed()

tests = [
_test_fp8_mlp_tensor_parallelism_eager,
_test_fp8_mlp_tensor_parallelism_compile,
]

for test in tqdm(tests, desc="Running tests"):
try:
test(device_mesh)
except Exception as e:
print(f"Test {test.__name__} failed with error: {e}")
raise e

torch.distributed.destroy_process_group()
117 changes: 16 additions & 101 deletions torchao/float8/distributed_utils.py
@@ -3,110 +3,25 @@
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any

import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_group
import torch.distributed._functional_collectives as funcol
from torch.distributed._tensor import DTensor

# from float8_tensor import Float8Tensor
from torchao.float8.float8_tensor import Float8Tensor

# additional differentiable distributed primitives for SP which are not in
# the Fairscale codebase


def _gather_along_first_dim(input_: torch.Tensor):
# same as https://github.com/facebookresearch/fairscale/blob/main/fairscale/nn/model_parallel/mappings.py#L67,
# but gather along first dim instead of last dim
group = get_model_parallel_group()

# Bypass the function if we are using only 1 GPU.
if torch.distributed.get_world_size(group=group) == 1:
return input_

# Size and dimension.
first_dim = 0
rank = torch.distributed.get_rank(group=group)
world_size = torch.distributed.get_world_size(group=group)

# If the input is a float8 tensor, we need to do the transformation on the
# inner tensor and then return a new wrapper.
def _transform(t):
# tensors must be contiguous for all_gather to work
input_contig = t.contiguous()

tensor_list = [torch.empty_like(input_contig) for _ in range(world_size)]
tensor_list[rank] = input_contig
torch.distributed.all_gather(tensor_list, input_contig, group=group)

# Note: torch.cat already creates a contiguous tensor.
output = torch.cat(tensor_list, dim=first_dim).contiguous()
return output

if isinstance(input_, Float8Tensor):
new_data = input_._data
new_data = new_data.view(torch.int8)
new_data = _transform(new_data)
new_data = new_data.view(input_._data.dtype)
output = Float8Tensor(new_data, input_._scale, input_._orig_dtype)
else:
output = _transform(input_)

return output


def _reduce_scatter(ctx: Any, input_: torch.Tensor):
group = get_model_parallel_group()
world_size = torch.distributed.get_world_size(group)

assert input_.shape[0] % world_size == 0
output_shape = (input_.shape[0] // world_size, *input_.shape[1:])
output = torch.empty(*output_shape, device=input_.device, dtype=input_.dtype)

torch.distributed.reduce_scatter_tensor(output, input_, group=group)
return output


def _split_along_first_dim(input_: torch.Tensor):
# this is needed for testing

# like fairscale.nn.model_parallel.mappings._split, but
# along the first dim instead of last dim

group = get_model_parallel_group()
local_rank = torch.distributed.get_rank(group)
world_size = torch.distributed.get_world_size(group)

assert input_.shape[0] % world_size == 0
input_list = torch.split(input_, input_.shape[0] // world_size)
return input_list[local_rank]


class _AllGatherFloat8FwReduceScatterBw(torch.autograd.Function):
@staticmethod
def forward(ctx, input_):
return _gather_along_first_dim(input_)

@staticmethod
def backward(ctx, grad_output):
return _reduce_scatter(ctx, grad_output)


class _ReduceScatterFwAllGatherFloat8Bw(torch.autograd.Function):
@staticmethod
def forward(ctx, input_):
return _reduce_scatter(ctx, input_)

@staticmethod
def backward(ctx, grad_output):
return _gather_along_first_dim(grad_output)


class _AllGatherFwSplitBw(torch.autograd.Function):
@staticmethod
def forward(ctx, input_):
return _gather_along_first_dim(input_)

@staticmethod
def backward(ctx, grad_output):
return _split_along_first_dim(grad_output)
def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool:
"""
Check if the tensor is already casted to fp8, works if the local
tensor is wrapped in DTensor.
"""
if isinstance(tensor, Float8Tensor):
return True
elif isinstance(tensor, DTensor):
# TODO: shall we stick to public API and directly use tensor.to_local() here?
return tensor_already_casted_to_fp8(tensor._local_tensor)
elif isinstance(tensor, funcol.AsyncCollectiveTensor):
return tensor_already_casted_to_fp8(tensor.elem)

return False
5 changes: 3 additions & 2 deletions torchao/float8/float8_linear.py
@@ -13,6 +13,7 @@
import torch.utils.checkpoint as checkpoint

from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType
from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
from torchao.float8.float8_scaling_utils import (
NoopFwToFloat8E5M2BwDelayed,
NoopFwToFloat8E5M2BwDynamic,
@@ -469,7 +470,7 @@ def cast_input_to_float8(
return input_fp8

def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]:
if isinstance(weight, Float8Tensor):
if tensor_already_casted_to_fp8(weight):
return None
if self.scaling_type_weight is ScalingType.DELAYED:
scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
@@ -497,7 +498,7 @@ def cast_weight_to_float8_t(
is_amax_initialized: bool,
weight_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if isinstance(weight, Float8Tensor):
if tensor_already_casted_to_fp8(weight):
return weight.t()
weight_fp8 = hp_tensor_and_scale_to_float8(
weight,
2 changes: 1 addition & 1 deletion torchao/float8/float8_scaling_utils.py
@@ -13,12 +13,12 @@
import torch

from torchao.float8.config import ScalingGranularity
from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
from torchao.float8.float8_tensor import (
Float8Tensor,
GemmInputRole,
LinearMMConfig,
hp_tensor_and_scale_to_float8,
tensor_already_casted_to_fp8,
)
from torchao.float8.float8_utils import (
amax_history_to_scale,
16 changes: 0 additions & 16 deletions torchao/float8/float8_tensor.py
@@ -7,7 +7,6 @@
from typing import Dict, NamedTuple, Optional

import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed._tensor import DTensor

from torchao.float8.float8_utils import (
@@ -121,21 +120,6 @@ def choose_scaled_mm_config(
raise AssertionError(f"unexpected a_role {a_role} and b_role {b_role}")


def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool:
"""
Check if the tensor is already casted to fp8
"""
if isinstance(tensor, Float8Tensor):
return True
elif isinstance(tensor, DTensor):
# TODO: shall we stick to public API and directly use tensor.to_local() here?
return tensor_already_casted_to_fp8(tensor._local_tensor)
elif isinstance(tensor, funcol.AsyncCollectiveTensor):
return tensor_already_casted_to_fp8(tensor.elem)

return False


@torch._dynamo.allow_in_graph
class _ToFloat8ConstrFunc(torch.autograd.Function):
"""