
[float8] improve eager numerics for dynamic scales and gets on par with torch.compile #904

Merged
43 commits, merged on Oct 1, 2024
Changes from 10 commits

Commits (43)
6bf0f5c
[float8] improve eager numerics for dynamic scales
weifengpy Sep 19, 2024
553687f
leave torch.linalg.vector_norm for another PR
weifengpy Sep 19, 2024
19a592d
cuda
weifengpy Sep 19, 2024
218290e
remove _data and investigate
weifengpy Sep 19, 2024
24ec914
remove _data comment
weifengpy Sep 19, 2024
c099486
upcast to float32 is enough
weifengpy Sep 21, 2024
b93ffc8
explain why float32
weifengpy Sep 21, 2024
ebff416
_data parity
weifengpy Sep 21, 2024
8978ab2
handle sm8.9
weifengpy Sep 21, 2024
f17dc12
fix transformer unit test
weifengpy Sep 22, 2024
511c751
print if error
weifengpy Sep 26, 2024
9becda1
Add tutorial for trainable tensor subclass (#908)
andrewor14 Sep 20, 2024
e4fdca9
Introducing 1-bit quantization for Llama in torchchat (#910)
vaishnavi17 Sep 20, 2024
0cd4d37
Rename Floating point to fp8 (#909)
jainapurva Sep 20, 2024
014558d
[float8] fix typo in bitwise_identical unit test (#918)
weifengpy Sep 23, 2024
3267402
Adding example for quantized tensor + tensor parallelism (#785)
jerryzh168 Sep 23, 2024
1e07eff
rename cuda mode -> gpu mode (#925)
msaroufim Sep 24, 2024
ebdeed0
Add workaround to recover the perf for quantized vit in torch.compile…
jerryzh168 Sep 24, 2024
09ffa22
clean up device checks in float8 unit test files (#923)
vkuzo Sep 24, 2024
0b8dd85
[low-bit optim] Change 8-bit and FP8 optim block size from 2048 to 25…
gau-nernst Sep 24, 2024
87faf04
Float8 autoquant weight only (#866)
jainapurva Sep 24, 2024
3a9fdb0
Fix failing FP6 benchmark (#931)
tobiasvanderwerff Sep 25, 2024
fc6c393
Remove two if statements in fp8 padding (#935)
y-sq Sep 25, 2024
0043ace
[Distributed] Improve sharding example (#937)
kwen2501 Sep 25, 2024
ab3435c
Add composable QAT quantizer (#938)
andrewor14 Sep 25, 2024
a05a40f
resolve conflict with latest main
weifengpy Sep 26, 2024
334891b
Add torchchat quantizer
metascroy Sep 25, 2024
c706139
Add compile tests to test suite (#906)
jerryzh168 Sep 26, 2024
93554c0
Fix up CMakeLists and reorganize some code locations
metascroy Sep 26, 2024
efd9bb9
[float8] all-reduce amax on dp mesh instead of global pg (#933)
weifengpy Sep 26, 2024
85126cc
int8 dynamic quant + bsr support (#821)
jcaip Sep 26, 2024
a5a426e
fixing some issues with our support for 70/405B models (#941)
HDCharles Sep 26, 2024
e7270f1
Update INT8 mixed-precision training test to be less flaky (#950)
gau-nernst Sep 26, 2024
352685c
Add executorch parallel
metascroy Sep 26, 2024
168cfe9
Merge branch 'weifengpy-dynamic_scale_numerics' into dynamic_scale_nu…
weifengpy Sep 26, 2024
5900c3e
Merge branch 'main' into dynamic_scale_numerics
weifengpy Sep 26, 2024
37e1479
test CI
weifengpy Sep 26, 2024
2efde49
better comment on why upcasting
weifengpy Sep 26, 2024
8c04f4f
control seed
weifengpy Sep 26, 2024
04b229b
move unit test to test_compile
weifengpy Sep 26, 2024
8b7c2ef
fix typo
weifengpy Sep 26, 2024
9346afd
float64 upcasting after allreduce
weifengpy Sep 27, 2024
3d0da20
use LinearMMConfig
weifengpy Sep 30, 2024
37 changes: 37 additions & 0 deletions test/float8/test_base.py
@@ -15,6 +15,9 @@

import torch
import torch.nn as nn
from torchao.float8.float8_scaling_utils import (
hp_tensor_to_float8_dynamic,
)

from torchao.utils import TORCH_VERSION_AT_LEAST_2_5

@@ -604,6 +607,40 @@ def test_small_amax_float16(self, float8_dtype):
x = torch.tensor([target_amax], dtype=torch.float16, device="cuda")
scale = tensor_to_scale(x, float8_dtype)
assert not torch.any(torch.isinf(scale))

@unittest.skipIf(
not is_cuda_8_9,
"CUDA not available",
)
@pytest.mark.parametrize(
"dtype",
[
torch.float32,
torch.bfloat16,
torch.float16,
],
)
def test_dynamic_scale_parity(self, dtype: torch.dtype):
Contributor:
nit: move to test_compile.py since this is testing compile vs eager?

scaling_type_weight = ScalingType.DYNAMIC
torch.manual_seed(0)
hp_tensor = torch.randn(768, 32, device="cuda", dtype=dtype)
float8_config = Float8LinearConfig(
cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
)
float8_eager = hp_tensor_to_float8_dynamic(
hp_tensor,
torch.float8_e4m3fn,
float8_config,
gemm_input_role=GemmInputRole.WEIGHT,
)
float8_compile = torch.compile(hp_tensor_to_float8_dynamic)(
hp_tensor,
torch.float8_e4m3fn,
float8_config,
gemm_input_role=GemmInputRole.WEIGHT,
)
assert torch.equal(float8_eager._scale, float8_compile._scale)
Contributor Author (@weifengpy, Sep 19, 2024):
Without the PR, the numerics look like the following:
eager _scale=106.5000 vs compile _scale=106.1925...

After the PR, eager is also 106.1925...

torch.testing.assert_close(float8_eager._data, float8_compile._data)


class TestFloat8LinearUtils(unittest.TestCase):
5 changes: 1 addition & 4 deletions test/float8/test_fsdp2/fsdp2_common.py
@@ -48,10 +48,7 @@ def check_parity_no_mp(
):
precompute_float8_dynamic_scale_for_fsdp(model)

- if compile_transformer_block:
-     test_cls.assertEqual(losses[0], losses[1], atol=1e-4, rtol=1e-4)
- else:
-     test_cls.assertEqual(losses[0], losses[1])
+ test_cls.assertEqual(losses[0], losses[1])


def check_parity_bf16_mp(
3 changes: 2 additions & 1 deletion torchao/float8/float8_tensor.py
@@ -163,7 +163,8 @@ def forward(

DTensor Invariant: DTensor must always be the outer most tensor subclass
"""
- tensor_scaled = tensor * scale
+ # scale is float32 thus upcasting tensor to match
Contributor:
can we make this comment contain the context? something like

# Note: when the line below is compiled with `torch.compile`, `tensor` is automatically upcasted to `float32` to multiply with the scale
# In order to match numerics between eager and compile, we upcast manually here.

+ tensor_scaled = tensor.to(torch.float32) * scale
Contributor Author:
Without upcasting, the eager numeric is -157.00000000000000000000 while compile gives -157.06507873535156250000.

Contributor Author:
torch.compile upcasts the tensor ahead of the multiply; see tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32) in the following output code:

@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 24576
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), None).to(tl.float32)

bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype)

if isinstance(bits_fp8, DTensor):
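
To make the upcasting discussion above concrete, here is a minimal, self-contained repro sketch; it is not code from this PR, it assumes a CUDA device and a recent PyTorch with torch.compile, and the shapes, scale value, and function name are illustrative:

import torch

def scale_hp_tensor(tensor: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # In eager, the product of a bfloat16 tensor and a 0-dim float32 scale stays
    # in bfloat16; under torch.compile, inductor loads the tensor as float32
    # before the multiply (see the triton output quoted above).
    return tensor * scale

hp_tensor = torch.randn(768, 32, device="cuda", dtype=torch.bfloat16)
scale = torch.tensor(106.1925, device="cuda", dtype=torch.float32)

eager_out = scale_hp_tensor(hp_tensor, scale)
compiled_out = torch.compile(scale_hp_tensor)(hp_tensor, scale)
# The two can disagree in the low-order bits; upcasting the tensor manually,
# as the diff above does, closes the gap.
matched_out = scale_hp_tensor(hp_tensor.to(torch.float32), scale).to(eager_out.dtype)
print(torch.equal(eager_out, compiled_out), torch.equal(matched_out, compiled_out))
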
2 changes: 2 additions & 0 deletions torchao/float8/float8_utils.py
@@ -42,6 +42,8 @@ def amax_to_scale(
float8_dtype: The float8 dtype.
orig_dtype: The original dtype of the tensor.
"""
# _scaled_mm requires float32 scale
Contributor:
nit: can we describe in more detail why we are upcasting here

amax = amax.to(torch.float64)
Contributor Author (@weifengpy, Sep 19, 2024):
Upcast amax in amax_to_scale instead of tensor_to_amax for two reasons:

  • we can still do the bfloat16 all-reduce for amax
  • it is safer for delayed scaling, since it won't change the dtype of amax_buffer

Contributor:
could you share why the upcasting happens?

Contributor Author:
I can look into inductor more to see how it achieves fp64.

Contributor Author:
torch.compile actually upcasts to float32 with tl.load(in_ptr0 + (x0), None).to(tl.float32). Upcasting to float64 helps further because torch.compile and eager show different numerics for 1.0 / float32 (but the same numerics for float64).

The float32 numeric difference can be verified with:

import torch

def upcast_reciprocal(inp: torch.Tensor):
    return inp.reciprocal()

inp = torch.full([], 0.00817871093750000000, device="cuda", dtype=torch.float32)
eager_scale = upcast_reciprocal(inp)
compile_scale = torch.compile(upcast_reciprocal)(inp)
fp64_ground_truth = inp.to(torch.float64).reciprocal()
# For this input the eager and compiled float32 reciprocals differ slightly, so the
# assert below fires; computing in float64 gives matching results.
assert torch.equal(eager_scale, compile_scale), f"{eager_scale=} vs {compile_scale=}, {fp64_ground_truth=}"

if float8_dtype in FP8_TYPES:
res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
else:
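
For context, a standalone sketch of the scale computation after this change, assuming a recent PyTorch with float8 dtypes; EPS, the signature, and the float16 handling are simplified relative to torchao.float8.float8_utils, so treat this as an illustration rather than the library API:

import torch

EPS = 1e-12  # assumption: stand-in for the EPS constant in float8_utils

def amax_to_scale_sketch(amax: torch.Tensor, float8_dtype: torch.dtype = torch.float8_e4m3fn) -> torch.Tensor:
    # Upcast to float64 so eager and torch.compile agree on the division
    # (in float32 the two can differ by one ulp, as shown in the thread above).
    amax = amax.to(torch.float64)
    res = torch.finfo(float8_dtype).max / torch.clamp(amax, min=EPS)
    # _scaled_mm requires a float32 scale, so downcast once at the end.
    return res.to(torch.float32)

scale = amax_to_scale_sketch(torch.tensor(0.00817871093750000000))
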
6 changes: 3 additions & 3 deletions torchao/float8/fsdp_utils.py
@@ -59,17 +59,17 @@ def precompute_float8_dynamic_scale_for_fsdp(module: nn.Module) -> None:
return

# inf-norm is equivalent to max(abs(w))
- max_weights = torch._foreach_norm(weights, ord=math.inf) # Partial
+ max_weights = torch._foreach_norm(weights, ord=math.inf, dtype=torch.float64) # Partial
Contributor:
add comment to describe upcasting

Contributor Author:
improved comment

Contributor:
dtype=torch.float64 only changes the accumulation dtype? if there is no noticeable cost to this, I wonder if we should be doing this in more places 🤔

Contributor Author:
Good question! Actually I just updated the code to do _foreach_norm in the original precision and do the float64 upcasting before calculating scales. That ensures a consistent implementation between precompute and float8_utils.amax_to_scale.

Back to your question: I checked ForeachReduceOp.cu and it dispatches to lpnorm_cleanup<scalar_t, NormType::LInf, out_t>. I'm not sure what's inside lpnorm_cleanup, but the inf-norm is just max(abs), so I'm not sure they accumulate numerics at all:
https://github.com/pytorch/pytorch/blob/a28b40fa74470058ca57d77652b9601bece2f4d5/aten/src/ATen/native/cuda/ForeachReduceOp.cu#L534-L535C19

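As a quick sanity check of the point above that the inf-norm is just max(abs) with no accumulation, a small illustration (not from the PR):

import math
import torch

w = torch.randn(16, 8)
# _foreach_norm with ord=inf reduces each tensor to max(abs(w)); there is no
# summation, hence no accumulation dtype to worry about.
inf_norm = torch._foreach_norm([w], ord=math.inf)[0]
assert torch.equal(inf_norm, w.abs().max())
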
amax_tensor = torch.stack(max_weights) # Partial
# clamp is dispatched through DTensor
# it will issue a single all-reduce
amax_tensor = torch.clamp(amax_tensor, EPS) # Replicate
scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor # Replicate
if amax_tensor.dtype is torch.float16:
scale_tensor = torch.clamp(scale_tensor, max=torch.finfo(torch.float16).max)
- local_scale_tensor = scale_tensor.to_local()
+ local_scale_tensor = scale_tensor.to_local().to(torch.float32)
for i, float8_linear in enumerate(float8_linears):
- float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i].to(torch.float32)
+ float8_linear.weight._local_tensor._precomputed_scale = local_scale_tensor[i]


# FSDP pads its local tensor on dim-0. The subclass should be preserved such
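
Putting the fsdp_utils.py changes together, here is a simplified sketch of the precompute flow the author describes settling on (norm in the original precision, float64 upcast before computing scales); the DTensor/FSDP mechanics, the all-reduce, and the float16 clamp are omitted, and the helper name is hypothetical:

import math
import torch

def precompute_fp8_scales_sketch(weights: list[torch.Tensor]) -> torch.Tensor:
    # Per-weight amax via the inf-norm (== max(abs)), kept in the original dtype.
    max_weights = torch._foreach_norm(weights, ord=math.inf)
    # Upcast to float64 before computing scales so eager matches torch.compile.
    amax_tensor = torch.stack(max_weights).to(torch.float64)
    amax_tensor = torch.clamp(amax_tensor, 1e-12)  # EPS stand-in
    scale_tensor = torch.finfo(torch.float8_e4m3fn).max / amax_tensor
    # torch._scaled_mm consumes float32 scales, so downcast once at the end.
    return scale_tensor.to(torch.float32)

weights = [torch.randn(64, 32), torch.randn(32, 16)]
scales = precompute_fp8_scales_sketch(weights)
print(scales.dtype, scales.shape)  # torch.float32, torch.Size([2])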