fix bug with float8 training + FSDP2 + TP #1327

Merged: 1 commit merged on Nov 22, 2024

25 changes: 1 addition & 24 deletions test/float8/test_dtensor.py
@@ -15,8 +15,6 @@

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

@@ -49,6 +47,7 @@
)
from torchao.float8.float8_utils import e4m3_dtype, tensor_to_scale
from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
from torchao.testing.float8.dtensor_utils import ToyModel


def setup_distributed():
@@ -59,28 +58,6 @@ def setup_distributed():
    return device_mesh


class FeedForward(nn.Module):
    """MLP based model"""

    def __init__(self):
        super(FeedForward, self).__init__()
        self.w1 = nn.Linear(16, 32, bias=False)
        self.w2 = nn.Linear(16, 32, bias=False)
        self.out_proj = nn.Linear(32, 16, bias=False)

    def forward(self, x):
        return self.out_proj(F.silu(self.w1(x)) * self.w2(x))


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.ffn = FeedForward()

    def forward(self, x):
        return self.ffn(x)


def _test_scaled_mm(mesh: DeviceMesh, size=16):
    device = mesh.device_type
    fp8_dtype = e4m3_dtype
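
The FeedForward/ToyModel classes deleted above now come from a shared testing utility. A minimal usage sketch of that shared import after this PR, assuming the relocated ToyModel matches the class removed here (w1/w2: Linear(16, 32), out_proj: Linear(32, 16)):

import torch

from torchao.testing.float8.dtensor_utils import ToyModel

# Assumes the relocated ToyModel is identical to the class deleted in this diff.
model = ToyModel()
x = torch.randn(4, 16)
out = model(x)          # out_proj(silu(w1(x)) * w2(x))
print(out.shape)        # torch.Size([4, 16])
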
4 changes: 4 additions & 0 deletions test/float8/test_dtensor.sh
@@ -8,4 +8,8 @@ if python -c 'import torch;print(torch.cuda.is_available())' | grep -q "False";
exit
fi

# integration tests for TP/SP
NCCL_DEBUG=WARN torchrun --nproc_per_node 2 test/float8/test_dtensor.py

# integration smoke tests for FSDP2 + TP
NCCL_DEBUG=WARN torchrun --nproc_per_node 4 test/float8/test_fsdp2_tp.py
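
The new smoke test is launched with 4 ranks because setup_distributed() in test_fsdp2_tp.py builds a 2-D mesh of shape (world_size // 2, 2). A hedged sketch of that mapping, assuming a `torchrun --nproc_per_node 4` launch on a CUDA machine:

import os

import torch
from torch.distributed.device_mesh import init_device_mesh

# Mirrors setup_distributed() in the new test: 4 ranks -> (dp, tp) = (2, 2).
world_size = int(os.environ.get("WORLD_SIZE", -1))
mesh = init_device_mesh(
    "cuda",
    (world_size // 2, 2),
    mesh_dim_names=("dp", "tp"),
)
if torch.distributed.get_rank() == 0:
    print(mesh["dp"].size(), mesh["tp"].size())  # 2 2

torch.distributed.destroy_process_group()
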
121 changes: 121 additions & 0 deletions test/float8/test_fsdp2_tp.py
@@ -0,0 +1,121 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
"""
Smoke test for composing float8 training with FSDP2 + TP on a toy model

Note: for now, this does not run in CI.
TODO(future): make this run in CI
"""

import copy
import os

import pytest
import torch

from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

if not TORCH_VERSION_AT_LEAST_2_5:
    pytest.skip("Unsupported PyTorch version", allow_module_level=True)

from torch.distributed._composable.fsdp import fully_shard
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module
from tqdm import tqdm

from torchao.float8 import Float8LinearConfig
from torchao.float8.float8_linear_utils import convert_to_float8_training
from torchao.float8.float8_tensor_parallel import (
    Float8ColwiseParallel,
    Float8RowwiseParallel,
)
from torchao.testing.float8.dtensor_utils import ToyModel


def setup_distributed():
    world_size = int(os.environ.get("WORLD_SIZE", -1))

    # https://pytorch.org/tutorials/recipes/distributed_device_mesh.html
    device_mesh = init_device_mesh(
        "cuda",
        (world_size // 2, 2),
        mesh_dim_names=("dp", "tp"),
    )
    # seed must be the same in all processes
    torch.manual_seed(1)
    return device_mesh


def _test_fp8_mlp_tensor_parallelism_base(
    mesh: DeviceMesh, size=16, compile: bool = False
):
    device = mesh.device_type

    config = Float8LinearConfig(
        emulate=True,
        enable_fsdp_float8_all_gather=True,
    )

    toy_model = ToyModel().to(device)

    tp_model = copy.deepcopy(toy_model)
    tp_model = convert_to_float8_training(tp_model, config=config)

    # apply TP
    tp_model = parallelize_module(
        tp_model,
        mesh["tp"],
        {
            "ffn.w1": Float8ColwiseParallel(),
            "ffn.w2": Float8ColwiseParallel(),
            "ffn.out_proj": Float8RowwiseParallel(),
        },
    )

    if compile:
        tp_model = torch.compile(tp_model)

    # apply FSDP
    fsdp_config = {"mesh": mesh["dp"]}
    tp_model = fully_shard(tp_model, **fsdp_config)

    x_fp32 = torch.rand(size, size * 2, size, device=device, requires_grad=False)
    x_fp32_tp_input = x_fp32.clone()

    tp_out = tp_model(x_fp32_tp_input)
    tp_out.sum().backward()
    torch.cuda.synchronize()

    # TODO(future PR): test numerics, and add more cases


def _test_fp8_mlp_tensor_parallelism_eager(mesh: DeviceMesh, size=16):
    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=False)


def _test_fp8_mlp_tensor_parallelism_compile(mesh: DeviceMesh, size=16):
    _test_fp8_mlp_tensor_parallelism_base(mesh, size, compile=True)


if __name__ == "__main__":
    # float8 only works on CUDA H100 so we only test cuda and we follow
    # other test files to not use TestCase but instead just add the test
    # cases in the main func.
    device_mesh = setup_distributed()

    tests = [
        _test_fp8_mlp_tensor_parallelism_eager,
        _test_fp8_mlp_tensor_parallelism_compile,
    ]

    for test in tqdm(tests, desc="Running tests"):
        try:
            test(device_mesh)
        except Exception as e:
            print(f"Test {test.__name__} failed with error: {e}")
            raise e

    torch.distributed.destroy_process_group()
117 changes: 16 additions & 101 deletions torchao/float8/distributed_utils.py
@@ -3,110 +3,25 @@
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any

import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_group
import torch.distributed._functional_collectives as funcol
from torch.distributed._tensor import DTensor

# from float8_tensor import Float8Tensor
from torchao.float8.float8_tensor import Float8Tensor

# additional differentiable distributed primitives for SP which are not in
# the Fairscale codebase


def _gather_along_first_dim(input_: torch.Tensor):

Review comment from the Contributor Author: all of this was dead code, deleting

    # same as https://github.com/facebookresearch/fairscale/blob/main/fairscale/nn/model_parallel/mappings.py#L67,
    # but gather along first dim instead of last dim
    group = get_model_parallel_group()

    # Bypass the function if we are using only 1 GPU.
    if torch.distributed.get_world_size(group=group) == 1:
        return input_

    # Size and dimension.
    first_dim = 0
    rank = torch.distributed.get_rank(group=group)
    world_size = torch.distributed.get_world_size(group=group)

    # If the input is a float8 tensor, we need to do the transformation on the
    # inner tensor and then return a new wrapper.
    def _transform(t):
        # tensors must be contiguous for all_gather to work
        input_contig = t.contiguous()

        tensor_list = [torch.empty_like(input_contig) for _ in range(world_size)]
        tensor_list[rank] = input_contig
        torch.distributed.all_gather(tensor_list, input_contig, group=group)

        # Note: torch.cat already creates a contiguous tensor.
        output = torch.cat(tensor_list, dim=first_dim).contiguous()
        return output

    if isinstance(input_, Float8Tensor):
        new_data = input_._data
        new_data = new_data.view(torch.int8)
        new_data = _transform(new_data)
        new_data = new_data.view(input_._data.dtype)
        output = Float8Tensor(new_data, input_._scale, input_._orig_dtype)
    else:
        output = _transform(input_)

    return output


def _reduce_scatter(ctx: Any, input_: torch.Tensor):
    group = get_model_parallel_group()
    world_size = torch.distributed.get_world_size(group)

    assert input_.shape[0] % world_size == 0
    output_shape = (input_.shape[0] // world_size, *input_.shape[1:])
    output = torch.empty(*output_shape, device=input_.device, dtype=input_.dtype)

    torch.distributed.reduce_scatter_tensor(output, input_, group=group)
    return output


def _split_along_first_dim(input_: torch.Tensor):
    # this is needed for testing

    # like fairscale.nn.model_parallel.mappings._split, but
    # along the first dim instead of last dim

    group = get_model_parallel_group()
    local_rank = torch.distributed.get_rank(group)
    world_size = torch.distributed.get_world_size(group)

    assert input_.shape[0] % world_size == 0
    input_list = torch.split(input_, input_.shape[0] // world_size)
    return input_list[local_rank]


class _AllGatherFloat8FwReduceScatterBw(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_):
        return _gather_along_first_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _reduce_scatter(ctx, grad_output)


class _ReduceScatterFwAllGatherFloat8Bw(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_):
        return _reduce_scatter(ctx, input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather_along_first_dim(grad_output)


class _AllGatherFwSplitBw(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input_):
        return _gather_along_first_dim(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _split_along_first_dim(grad_output)
def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool:
    """
    Check if the tensor is already casted to fp8, works if the local
    tensor is wrapped in DTensor.
    """
    if isinstance(tensor, Float8Tensor):
        return True
    elif isinstance(tensor, DTensor):
        # TODO: shall we stick to public API and directly use tensor.to_local() here?
        return tensor_already_casted_to_fp8(tensor._local_tensor)
    elif isinstance(tensor, funcol.AsyncCollectiveTensor):
        return tensor_already_casted_to_fp8(tensor.elem)

    return False
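
A minimal usage sketch of the relocated helper, assuming a torchao build with this change installed; only a plain high-precision tensor is exercised, since constructing a Float8Tensor or DTensor by hand is outside the scope of a sketch:

import torch

from torchao.float8.distributed_utils import tensor_already_casted_to_fp8

x = torch.randn(16, 16)
# A plain fp32 tensor has not been cast yet, so Float8Linear should cast it.
assert tensor_already_casted_to_fp8(x) is False
# A Float8Tensor, a DTensor whose local tensor is a Float8Tensor, or an
# AsyncCollectiveTensor wrapping one of those would return True instead.
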
5 changes: 3 additions & 2 deletions torchao/float8/float8_linear.py
@@ -13,6 +13,7 @@
import torch.utils.checkpoint as checkpoint

from torchao.float8.config import Float8LinearConfig, ScalingGranularity, ScalingType
from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
from torchao.float8.float8_scaling_utils import (
    NoopFwToFloat8E5M2BwDelayed,
    NoopFwToFloat8E5M2BwDynamic,
@@ -469,7 +470,7 @@ def cast_input_to_float8(
        return input_fp8

    def get_weight_scale(self, weight: torch.Tensor) -> Optional[torch.Tensor]:
        if isinstance(weight, Float8Tensor):
        if tensor_already_casted_to_fp8(weight):
            return None
        if self.scaling_type_weight is ScalingType.DELAYED:
            scale_fn_name = self.config.delayed_scaling_config.scale_fn_name
@@ -497,7 +498,7 @@ def cast_weight_to_float8_t(
        is_amax_initialized: bool,
        weight_scale: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if isinstance(weight, Float8Tensor):
        if tensor_already_casted_to_fp8(weight):
            return weight.t()
        weight_fp8 = hp_tensor_and_scale_to_float8(
            weight,
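
The two replaced isinstance checks are presumably the bug referenced in the PR title: with FSDP2 float8 all-gather composed with TP, the weight reaching Float8Linear can be a DTensor whose local tensor is a Float8Tensor, which isinstance(weight, Float8Tensor) does not catch, so the weight would be cast to float8 a second time. A self-contained sketch of the difference, using hypothetical stand-in classes (FakeFloat8Tensor / FakeDTensor are illustrative only, not torchao or torch types):

class FakeFloat8Tensor:  # hypothetical stand-in for torchao's Float8Tensor
    pass


class FakeDTensor:  # hypothetical stand-in for torch.distributed._tensor.DTensor
    def __init__(self, local_tensor):
        self._local_tensor = local_tensor


def already_fp8_old(t):
    # old check in Float8Linear: misses a DTensor-wrapped float8 weight
    return isinstance(t, FakeFloat8Tensor)


def already_fp8_new(t):
    # new check: unwrap DTensor first, mirroring tensor_already_casted_to_fp8
    if isinstance(t, FakeFloat8Tensor):
        return True
    if isinstance(t, FakeDTensor):
        return already_fp8_new(t._local_tensor)
    return False


wrapped = FakeDTensor(FakeFloat8Tensor())
print(already_fp8_old(wrapped))  # False: the weight would be re-cast (the bug)
print(already_fp8_new(wrapped))  # True: the extra cast is skipped
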
2 changes: 1 addition & 1 deletion torchao/float8/float8_scaling_utils.py
@@ -13,12 +13,12 @@
import torch

from torchao.float8.config import ScalingGranularity
from torchao.float8.distributed_utils import tensor_already_casted_to_fp8
from torchao.float8.float8_tensor import (
    Float8Tensor,
    GemmInputRole,
    LinearMMConfig,
    hp_tensor_and_scale_to_float8,
    tensor_already_casted_to_fp8,
)
from torchao.float8.float8_utils import (
    amax_history_to_scale,
16 changes: 0 additions & 16 deletions torchao/float8/float8_tensor.py
@@ -7,7 +7,6 @@
from typing import Dict, NamedTuple, Optional

import torch
import torch.distributed._functional_collectives as funcol
from torch.distributed._tensor import DTensor

from torchao.float8.float8_utils import (
@@ -121,21 +120,6 @@ def choose_scaled_mm_config(
        raise AssertionError(f"unexpected a_role {a_role} and b_role {b_role}")


def tensor_already_casted_to_fp8(tensor: torch.Tensor) -> bool:
    """
    Check if the tensor is already casted to fp8
    """
    if isinstance(tensor, Float8Tensor):
        return True
    elif isinstance(tensor, DTensor):
        # TODO: shall we stick to public API and directly use tensor.to_local() here?
        return tensor_already_casted_to_fp8(tensor._local_tensor)
    elif isinstance(tensor, funcol.AsyncCollectiveTensor):
        return tensor_already_casted_to_fp8(tensor.elem)

    return False


@torch._dynamo.allow_in_graph
class _ToFloat8ConstrFunc(torch.autograd.Function):
    """