Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor custom FPx cast #363

Merged
merged 11 commits into from
Jun 17, 2024
6 changes: 3 additions & 3 deletions test/prototype/test_fp6_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_to_tc_float6_e3m2_compile(self, device):
x = torch.randn(256, 64, device=device)

expected = to_tc_float6_e3m2(x)
actual = torch.compile(to_tc_float6_e3m2)(x)
actual = torch.compile(to_tc_float6_e3m2, fullgraph=True)(x)
torch.testing.assert_close(actual, expected)

@parametrize("device", _DEVICES)
Expand All @@ -53,7 +53,7 @@ def test_from_tc_float6_e3m2_compile(self, device):
x = torch.randint(256, size=(M, N * 3 // 4), dtype=torch.uint8, device=device)

expected = from_tc_float6_e3m2(x)
actual = torch.compile(from_tc_float6_e3m2)(x)
actual = torch.compile(from_tc_float6_e3m2, fullgraph=True)(x)
torch.testing.assert_close(actual, expected)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
Expand Down Expand Up @@ -81,7 +81,7 @@ def test_fp6_llm_linear_compile(self, bias):

x = torch.randn(N, IC, device=device, dtype=torch.half)
expected = fp6_linear(x)
actual = torch.compile(fp6_linear)(x)
actual = torch.compile(fp6_linear, fullgraph=True)(x)
torch.testing.assert_close(actual, expected)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
Expand Down
202 changes: 202 additions & 0 deletions torchao/prototype/custom_fp_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# This script was initially developed for sub-byte MX dtypes (FP4 E2M1, FP6 E3M2, and FP6 E2M3).
# It has been refactored to support any sub-byte FP dtypes. However, some behaviors of MX dtypes remain:
# 1. No encodings are reserved for special values (+/-inf, NaN).
# 2. When downcasting from FP32 to FPx,
# - Rounding mode is round to nearest, ties to even.
# - Values outside the representable range of FPx after rounding are clamped to the maximum FPx
# magnitude (sign is preserved).

import torch
from torch import Tensor


def _n_ones(n: int) -> int:
return (1 << n) - 1


# FP32 layout used throughout this file: 8 exponent bits, 23 mantissa bits.
EBITS_F32 = 8
MBITS_F32 = 23

# FP32 exponent bias: all ones in (EBITS_F32 - 1) bits, i.e. 127.
F32_EXP_BIAS = _n_ones(EBITS_F32 - 1)


def _f32_to_fpx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
    """Convert FP32 values to a sub-byte floating point format (FPx).

    The target format has 1 sign bit, ``ebits`` exponent bits and ``mbits``
    mantissa bits. The result is "unpacked": every output element is a single
    ``torch.uint8`` whose ``1 + ebits + mbits`` least-significant bits hold the
    FPx encoding (sign in the highest of those bits, then exponent, then
    mantissa). The remaining high bits are zero.

    Conversion behavior (see also the notes at the top of this file):
      - Rounding mode is round to nearest, ties to even.
      - Magnitudes at or above the largest FPx value are clamped to the
        largest FPx magnitude; the sign is preserved.
      - No encodings are reserved for special values (+/-inf, NaN).

    Args:
        x: tensor of dtype ``torch.float`` (FP32).
        ebits: number of exponent bits of the target format.
        mbits: number of mantissa bits of the target format.

    Returns:
        A ``torch.uint8`` tensor with the same shape as ``x``.

    Raises:
        AssertionError: if ``x`` is not FP32, or if the target format does not
            fit into 8 bits (``1 + ebits + mbits > 8``).
    """
    assert x.dtype == torch.float
    assert 1 + ebits + mbits <= 8

    # calculate constants
    exp_bias = _n_ones(ebits - 1)
    max_int = _n_ones(ebits + mbits)
    sign_mask = 1 << (ebits + mbits)

    # Rounding bias for the normal path. (MBITS_F32 - mbits) mantissa bits are
    # dropped by the final right shift; this constant is "half of the dropped
    # range, minus one". Together with the `mant_odd` correction below (+1 when
    # the kept mantissa is odd) it makes the truncation behave as
    # round to nearest, ties to even.
    magic_adder = _n_ones(MBITS_F32 - mbits - 1)

    # all E bits and M bits are 1s
    max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits))

    # E bits = 1, M bits = 0
    min_normal = 2 ** (1 - exp_bias)

    denorm_exp = (
        # exp bias conversion between formats
        (F32_EXP_BIAS - exp_bias)
        # mantissa length difference between formats
        + (MBITS_F32 - mbits)
        # add one to encoded exponent for denormalized numbers
        + 1
    )
    denorm_mask_int = denorm_exp << MBITS_F32

    # reinterpret int32 as float32
    denorm_mask_float = torch.tensor(denorm_mask_int, dtype=torch.int32).view(torch.float32)

    # save the sign
    # Note that we have torch.uint32, but some ops like cpu bit shifts
    # do not work on it. So, we stay in int32.
    x = x.view(torch.int32)
    sign = x & 0x80000000

    # set everything to positive, will add sign back at the end
    # (this creates a new tensor, so the caller's input is not modified by the
    # in-place ops further down)
    x = x ^ sign

    # TODO: can the branch floating point comparisons below be done without
    # converting to float? probably but need to verify
    x = x.view(torch.float)

    # rewrite saturate/denorm/norm branches without explicit data dependent
    # control flow, to be more compiler friendly
    saturate_mask = x >= max_normal
    denormal_mask = torch.logical_and(torch.logical_not(saturate_mask), x < min_normal)
    normal_mask = torch.logical_not(torch.logical_or(saturate_mask, denormal_mask))

    #
    # branch 1: saturate to max val - handled later in the code which combines
    # the branches
    #

    #
    # branch 2: conversion to denormal as well as rounding up to normal
    #
    # Adding `denorm_mask_float` lets the FP32 addition itself align and round
    # the value; subtracting the same bit pattern afterwards leaves the FPx
    # encoding in the low bits of the int32.
    denormal_x = x + denorm_mask_float
    denormal_x = denormal_x.view(torch.int32)
    denormal_x -= denorm_mask_int
    denormal_x = denormal_x.to(torch.uint8)

    #
    # branch 3: stay in normal range, adjust the exponent and round
    #
    # NOTE: `normal_x` shares memory with `x`. The in-place updates below are
    # safe only because the masks and `denormal_x` were already computed.
    normal_x = x.view(torch.int32)
    # resulting mantissa is odd
    mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1
    # update exponent, rounding bias part 1
    val_to_add = ((exp_bias - F32_EXP_BIAS) << MBITS_F32) + magic_adder
    normal_x += val_to_add
    # rounding bias part 2
    normal_x += mant_odd
    # take the bits!
    normal_x = normal_x >> (MBITS_F32 - mbits)
    normal_x = normal_x.to(torch.uint8)

    #
    # combine the branches
    #
    # start from the saturated value and overwrite the other two branches
    x = torch.full_like(x, max_int, dtype=torch.uint8)
    x = torch.where(denormal_mask, denormal_x, x)
    x = torch.where(normal_mask, normal_x, x)

    # add sign back
    sign_lp = sign >> (MBITS_F32 + EBITS_F32 - mbits - ebits)
    sign_lp = sign_lp.to(torch.uint8)
    # Right shift of a negative signed integer can fill the least significant
    # bits with either 1s or 0s, depending on the implementation. Since we
    # stay in int32 instead of uint32 (see above), we mask out these bits to
    # get just the FPx sign bit
    sign_lp = sign_lp & sign_mask
    x = x | sign_lp

    return x.to(torch.uint8)


def _fpx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor:
    """Convert unpacked sub-byte floating point values (FPx) to FP32.

    Each input element is one ``torch.uint8`` whose ``1 + ebits + mbits``
    least-significant bits hold an FPx encoding (sign, then ``ebits`` exponent
    bits, then ``mbits`` mantissa bits), i.e. the layout produced by
    ``_f32_to_fpx_unpacked``. No encodings are treated as special values
    (+/-inf, NaN).

    NOTE(review): the code does not mask off bits above the sign bit, so it
    presumably expects those unused high bits to be zero - confirm with callers.

    TODO(future): check if LUT for everything is faster than bit shifting,
    especially for fp4. (Per the review discussion, this might be faster, like
    the current NF4 implementation, but it has not been benchmarked.)

    Args:
        x: tensor of dtype ``torch.uint8`` holding one FPx value per element.
        ebits: number of exponent bits of the source format.
        mbits: number of mantissa bits of the source format.

    Returns:
        A ``torch.float`` (FP32) tensor with the same shape as ``x``.

    Raises:
        AssertionError: if ``x`` is not uint8, or if the source format does not
            fit into 8 bits (``1 + ebits + mbits > 8``).
    """
    assert x.dtype == torch.uint8
    assert 1 + ebits + mbits <= 8

    sign_mask = 1 << (ebits + mbits)
    exp_bias = _n_ones(ebits - 1)
    mantissa_mask = _n_ones(mbits)

    # save the sign
    sign_lp = x & sign_mask

    # set everything to positive, will add sign back at the end
    x_pos = x ^ sign_lp

    #
    # 1. Calculate zero mask
    #
    zero_mask = x_pos == 0

    #
    # 2. Calculate the denormal path mask
    #
    # non-zero value whose exponent field is all zeros
    denormal_mask = torch.logical_and((x_pos > 0), ((x_pos >> mbits) == 0))

    #
    # 3. Calculate the normal path
    #

    # calculate the new exponent and shift it to bits 2:9 of the result
    exp_biased_lp = x_pos >> mbits
    exp_biased_f32 = exp_biased_lp - exp_bias + F32_EXP_BIAS
    exp_biased_f32 = exp_biased_f32.to(torch.int32) << MBITS_F32

    # shift the mantissa to bits 10:32 of the result
    mantissa_lp_int32 = (x_pos & mantissa_mask).to(torch.int32)
    mantissa_f32 = mantissa_lp_int32 << (MBITS_F32 - mbits)
    result = exp_biased_f32 | mantissa_f32

    #
    # 4. Add the zero and denormal casts to the already casted normal path
    #
    result[zero_mask] = 0

    denormal_exp_biased = 1 - exp_bias + F32_EXP_BIAS

    # fast path.
    # without this, performance for FP4_E2M1 is slower by 2x
    if mbits == 1:
        # with a single mantissa bit there is only one denormal magnitude
        result[denormal_mask] = (denormal_exp_biased - mbits) << MBITS_F32

    else:
        # iterate over all possible values of mantissa
        # i=0, j=1
        # i=1, j=10,11
        # i=2, j=100,101,110,111
        # and so on
        for i in range(mbits):
            for mantissa_cmp in range(1 << i, 1 << (i + 1)):
                # left shift mantissa until it overflows (create an implicit 1)
                # subtract exponent by the same amount
                left_shift = mbits - i
                # plain Python ints; named so they do not shadow the tensors
                # `mantissa_f32` / `exp_biased_f32` computed above
                denorm_mantissa_bits = (mantissa_cmp - (1 << i)) << (
                    left_shift + MBITS_F32 - mbits
                )
                denorm_exp_bits = (denormal_exp_biased - left_shift) << MBITS_F32

                # we can update this in-place since the values won't overlap
                mantissa_lp_int32[mantissa_lp_int32 == mantissa_cmp] = (
                    denorm_exp_bits | denorm_mantissa_bits
                )

        # `mantissa_lp_int32` now holds FP32 bit patterns for the denormal
        # inputs; splice them into the normal-path result
        result = torch.where(denormal_mask, mantissa_lp_int32, result)

    # add sign back
    # (moves the FPx sign bit from position ebits + mbits up to bit 31)
    sign_f32 = sign_lp.to(torch.int32) << (MBITS_F32 - mbits + EBITS_F32 - ebits)
    result = result | sign_f32

    return result.view(torch.float)
4 changes: 2 additions & 2 deletions torchao/prototype/mx_formats/benchmarks/bench_qdq.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ def run(profile_folder: Optional[str] = None):
data_lp = MXTensor.to_mx(data_hp, elem_dtype, block_size=32)

if not use_fp4_custom_triton_dequant_kernel:
quant = torch.compile(MXTensor.to_mx)
dequant = torch.compile(data_lp.to_dtype)
quant = torch.compile(MXTensor.to_mx, fullgraph=True)
dequant = torch.compile(data_lp.to_dtype, fullgraph=True)
else:
# As of 2024-04, torch.compile didn't work with the
# handwritten triton kernel,
Expand Down
Loading
Loading