[DO NOT REVIEW] debug float8 all-gather numerics #873

Open
wants to merge 27 commits into base: main

Commits (27)
e17155c
debug numerics
weifengpy Sep 11, 2024
544a09c
print all mismatch
weifengpy Sep 11, 2024
dd357d6
init linears
weifengpy Sep 11, 2024
a880ec2
repro with single linear
weifengpy Sep 12, 2024
1ff11c9
bitwise equal
weifengpy Sep 12, 2024
2059918
bitwise equal
weifengpy Sep 12, 2024
e8385b2
remove change on original test
weifengpy Sep 12, 2024
8109bbe
restore
weifengpy Sep 12, 2024
72f8d7e
restore
weifengpy Sep 12, 2024
760dff9
clean repro
weifengpy Sep 12, 2024
c39d2d8
make baseline FSDP too
weifengpy Sep 12, 2024
a3baf9f
delayed linears are on par
weifengpy Sep 13, 2024
095a2c5
remove unnecessary change
weifengpy Sep 13, 2024
a263329
numerics on par with float32
weifengpy Sep 15, 2024
653444c
add fully_shard to float32
weifengpy Sep 16, 2024
0a4b91d
compile float8Linear on-par
weifengpy Sep 16, 2024
fc22626
compile + transformer root works
weifengpy Sep 16, 2024
05f467e
bfloat16 works fine for torch.compile + float8linear, basic root
weifengpy Sep 16, 2024
c584dca
float8 with/o precompute + bf16/fp32 on par on single float8linear and
weifengpy Sep 17, 2024
569f862
bf16/fp32 + float8linear bitwise equal, precompute=False
weifengpy Sep 18, 2024
003dfc3
bf16/fp32 parity for full transformer, precompute=False
weifengpy Sep 18, 2024
37fb7a5
float8linear bitwise on par between eager and compile
weifengpy Sep 19, 2024
8450886
eager is on par with float64 numerics. fixing torch.compile
weifengpy Sep 19, 2024
f4688ae
trying to replicate bfloat16 in test_base unit test
weifengpy Sep 19, 2024
432230a
test_base numerics on reciprocal
weifengpy Sep 21, 2024
78b8a8d
_data parity
weifengpy Sep 21, 2024
fc6063c
transformer on par because of upcasting to float64
weifengpy Sep 22, 2024
37 changes: 36 additions & 1 deletion test/float8/test_base.py
@@ -8,13 +8,18 @@
import itertools
import random
import re
import math
import unittest
import warnings

import pytest

import torch
import torch.nn as nn
from torchao.float8.float8_scaling_utils import (
    hp_tensor_to_float8_dynamic,
    hp_tensor_to_float8_dynamic_debug,
)

from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

@@ -53,7 +58,7 @@
is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)

def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
    assert torch.all(a._data == b._data).item(), "scales are not identical"
    assert torch.all(a._scale == b._scale).item(), "scales are not identical"
    assert torch.all(a._data == b._data).item(), "data is not identical"
    return True

@@ -604,6 +609,36 @@ def test_small_amax_float16(self, float8_dtype):
        x = torch.tensor([target_amax], dtype=torch.float16, device="cuda")
        scale = tensor_to_scale(x, float8_dtype)
        assert not torch.any(torch.isinf(scale))

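    # Compare eager vs torch.compile for the dynamic float8 cast: the quantized _data should be bitwise identical.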
    @pytest.mark.parametrize(
        "dtype",
        [
            torch.float32,
            torch.bfloat16,
            torch.float16,
        ],
    )
    def test_float8_data_parity(self, dtype: torch.dtype):
        scaling_type_weight = ScalingType.DYNAMIC
        torch.manual_seed(0)
        hp_tensor = torch.randn(24, 2, device="cuda", dtype=dtype)
        float8_config = Float8LinearConfig(
            cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
        )
        float8_eager = hp_tensor_to_float8_dynamic(
            hp_tensor,
            torch.float8_e4m3fn,
            float8_config,
            gemm_input_role=GemmInputRole.WEIGHT,
        )
        float8_compile = torch.compile(hp_tensor_to_float8_dynamic)(
            hp_tensor,
            torch.float8_e4m3fn,
            float8_config,
            gemm_input_role=GemmInputRole.WEIGHT,
        )
        assert torch.equal(float8_eager._data, float8_compile._data)


class TestFloat8LinearUtils(unittest.TestCase):
146 changes: 139 additions & 7 deletions test/float8/test_fsdp2/fsdp2_common.py
@@ -12,9 +12,145 @@
    sync_float8_amax_and_scale_history,
)
from torchao.float8.fsdp_utils import precompute_float8_dynamic_scale_for_fsdp
import os


def check_parity_no_mp(
@contextlib.contextmanager
def enable_profiling(enable=False):
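    """Optionally wrap the training loop in torch.profiler; yields the profiler object, or None when disabled."""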
    if not enable:
        torch_profiler = contextlib.nullcontext()
        yield None
    else:
        trace_dir = "./profilers"
        rank = torch.distributed.get_rank()
        def trace_handler(prof):
            curr_trace_dir_name = "iteration_" + str(prof.step_num)
            curr_trace_dir = os.path.join(trace_dir, curr_trace_dir_name)
            if not os.path.exists(curr_trace_dir):
                os.makedirs(curr_trace_dir, exist_ok=True)
            prof.export_chrome_trace(f"{curr_trace_dir}/rank{rank}_trace.json")
            torch.distributed.barrier()
        if not os.path.exists(trace_dir):
            os.makedirs(trace_dir, exist_ok=True)
        warmup, active = 1, 2
        wait = 1
        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active),
            on_trace_ready=trace_handler,
            record_shapes=True,
        ) as torch_profiler:
            yield torch_profiler


def run_training_loop(
    test_cls,
    model: nn.Module,
    optim: torch.optim.Optimizer,
    local_inp: torch.Tensor,
    steps,
    float8_config: Float8LinearConfig,
    dtype: torch.dtype,
    seed: int,
    precompute: bool = False,
):
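    """Run `steps` training iterations and record the per-step loss, parameter sum, and gradient sum so runs can be compared bitwise."""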
    torch._dynamo.reset()
    losses = []
    param_sums = []
    grad_sums = []
    torch.manual_seed(seed)
    with enable_profiling(False) as torch_profiler:
        for iter_idx in range(steps):
            # local_inp = torch.rand(16, 16, 768, device="cuda", dtype=dtype)
            optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            loss = model(local_inp).sum()
            losses.append(loss)
            loss.backward()
            param_sum = torch.concat(list(x.full_tensor().reshape(-1) for x in model.parameters())).sum()
            grad_sum = torch.concat(list(x.grad.full_tensor().reshape(-1) for x in model.parameters())).sum()
            # param_sum = torch.stack(list(x.reshape(-1) for x in model.parameters())).sum()
            # grad_sum = torch.stack(list(x.grad.reshape(-1) for x in model.parameters())).sum()
            param_sums.append(param_sum)
            grad_sums.append(grad_sum)
            if linear_requires_sync(float8_config):
                sync_float8_amax_and_scale_history(model)
            optim.step()
            if (
                precompute
                and float8_config.cast_config_weight.scaling_type is ScalingType.DYNAMIC
            ):
                precompute_float8_dynamic_scale_for_fsdp(model)
            if torch_profiler:
                torch_profiler.step()
    return losses, param_sums, grad_sums


def compare_numerics(
    test_cls,
    losses1: List[torch.Tensor],
    param_sums1: List[torch.Tensor],
    grad_sums1: List[torch.Tensor],
    losses2: List[torch.Tensor],
    param_sums2: List[torch.Tensor],
    grad_sums2: List[torch.Tensor],
):
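    """Assert that two recorded runs match bitwise at every step: losses, parameter sums, and gradient sums."""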
    assert len(losses1) == len(losses2)
    steps = len(losses1)
    for i in range(steps):
        # test_cls.assertEqual(losses1[i], losses2[i], f"loss different at {i}: {losses1[i]} vs {losses2[i]}")
        # test_cls.assertEqual(param_sums1[i], param_sums2[i], f"param_sum different at {i}: {param_sums1[i]} vs {param_sums2[i]}")
        # test_cls.assertEqual(grad_sums1[i], grad_sums2[i], f"grad_sum different at {i}: {grad_sums1[i]} vs {grad_sums2[i]}")
        assert torch.equal(param_sums1[i], param_sums2[i]), f"param_sum different at {i}: {param_sums1[i]} vs {param_sums2[i]}"
        assert torch.equal(losses1[i], losses2[i]), f"loss different at {i}: {losses1[i]} vs {losses2[i]}"
        assert torch.equal(grad_sums1[i], grad_sums2[i]), f"grad_sum different at {i}: {grad_sums1[i]} vs {grad_sums2[i]}"


def check_parity_compile(
    test_cls,
    ref_model: nn.Module,
    ref_optim: torch.optim.Optimizer,
    fsdp_model: nn.Module,
    fsdp_optim: torch.optim.Optimizer,
    local_inp: torch.Tensor,
    precompute: bool = False,
    config: Optional[Float8LinearConfig] = None,
):
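    """Train the reference model first, then the FSDP model, on the same input, asserting bitwise-equal loss, parameter sum, and gradient sum at every iteration."""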
    ref_losses: List[torch.Tensor] = []
    ref_param_sums: List[torch.Tensor] = []
    ref_grad_sums: List[torch.Tensor] = []
    for model, optim in ((ref_model, ref_optim), (fsdp_model, fsdp_optim)):
        torch._dynamo.reset()
        for iter_idx in range(1000):
            optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
            loss = model(local_inp).sum()
            loss.backward()

            if linear_requires_sync(config):
                sync_float8_amax_and_scale_history(model)

            param_sum = torch.stack([param.sum() for param in model.parameters()]).sum()
            grad_sum = torch.stack([param.grad.sum() for param in model.parameters()]).sum()
            if model is ref_model:
                ref_losses.append(loss)
                ref_param_sums.append(param_sum)
                ref_grad_sums.append(grad_sum)
            else:
                assert torch.equal(loss, ref_losses[iter_idx]), f"loss different at {iter_idx}: {loss} vs {ref_losses[iter_idx]}"
                assert torch.equal(param_sum, ref_param_sums[iter_idx]), f"param_sum different at {iter_idx}: {param_sum} vs {ref_param_sums[iter_idx]}"
                assert torch.equal(grad_sum, ref_grad_sums[iter_idx]), f"grad_sum different at {iter_idx}: {grad_sum} vs {ref_grad_sums[iter_idx]}"
            optim.step()
            if (
                model is fsdp_model
                and precompute
                and config.cast_config_weight.scaling_type is ScalingType.DYNAMIC
            ):
                precompute_float8_dynamic_scale_for_fsdp(model)


def check_parity_eager_ddp_no_mp(
    test_cls,
    ref_model: nn.Module,
    ref_optim: torch.optim.Optimizer,
@@ -23,7 +159,6 @@ def check_parity_no_mp(
    local_inp: torch.Tensor,
    precompute: bool = False,
    config: Optional[Float8LinearConfig] = None,
    compile_transformer_block: bool = False,
):
    # TODO(before land): reorder args and make config not optional
    for iter_idx in range(10):
@@ -48,13 +183,10 @@
            ):
                precompute_float8_dynamic_scale_for_fsdp(model)

        if compile_transformer_block:
            test_cls.assertEqual(losses[0], losses[1], atol=1e-4, rtol=1e-4)
        else:
            test_cls.assertEqual(losses[0], losses[1])
        assert torch.equal(losses[0], losses[1]), f"loss different at {iter_idx}: {losses[0]} vs {losses[1]}"


def check_parity_bf16_mp(
def check_parity_eager_ddp_bf16_mp(
    test_cls,
    ref_model: nn.Module,
    ref_model_bf16: nn.Module,