Skip to content

Commit

Permalink
Refactor custom FPx cast (pytorch#363)
Browse files Browse the repository at this point in the history
* refactor custom fp cast

* add dequant

* small formating

* compile with fullgraph=True

* add fullgraph=true

* undo

* add another version

* fast path for mbits=1

* add back docstring
  • Loading branch information
gau-nernst authored Jun 17, 2024
1 parent 246fad7 commit 3b848bd
Show file tree
Hide file tree
Showing 4 changed files with 239 additions and 348 deletions.
6 changes: 3 additions & 3 deletions test/prototype/test_fp6_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def test_to_tc_float6_e3m2_compile(self, device):
x = torch.randn(256, 64, device=device)

expected = to_tc_float6_e3m2(x)
actual = torch.compile(to_tc_float6_e3m2)(x)
actual = torch.compile(to_tc_float6_e3m2, fullgraph=True)(x)
torch.testing.assert_close(actual, expected)

@parametrize("device", _DEVICES)
Expand All @@ -53,7 +53,7 @@ def test_from_tc_float6_e3m2_compile(self, device):
x = torch.randint(256, size=(M, N * 3 // 4), dtype=torch.uint8, device=device)

expected = from_tc_float6_e3m2(x)
actual = torch.compile(from_tc_float6_e3m2)(x)
actual = torch.compile(from_tc_float6_e3m2, fullgraph=True)(x)
torch.testing.assert_close(actual, expected)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
Expand Down Expand Up @@ -81,7 +81,7 @@ def test_fp6_llm_linear_compile(self, bias):

x = torch.randn(N, IC, device=device, dtype=torch.half)
expected = fp6_linear(x)
actual = torch.compile(fp6_linear)(x)
actual = torch.compile(fp6_linear, fullgraph=True)(x)
torch.testing.assert_close(actual, expected)

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
Expand Down
227 changes: 227 additions & 0 deletions torchao/prototype/custom_fp_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# This script was initially developed for sub-byte MX dtypes (FP4 E2M1, FP6 E3M2, and FP6 E2M3).
# It has been refactored to support any sub-byte FP dtypes. However, some behaviors of MX dtypes remain:
# 1. No encodings are reserved for special values (+/-inf, NaN).
# 2. When downcasting from FP32 to FPx,
# - Rounding mode is round to nearest, ties to even.
# - Values outside the representable range of FPx after rounding are clamped to the maximum FPx
# magnitude (sign is preserved).

import torch
from torch import Tensor


def _n_ones(n: int) -> int:
return (1 << n) - 1


EBITS_F32, MBITS_F32 = 8, 23
F32_EXP_BIAS = _n_ones(EBITS_F32 - 1)


def _f32_to_fpx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
"""Convert FP32 numbers to sub-byte floating point numbers with the given
number of exponent and mantissa bits.
Input: torch.Tensor of dtype torch.float
Output: torch.Tensor of dtype torch.uint8, where the bit encoding is stored
in the least significant bits. e.g.
fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding
fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding
Note: there are no special values (NaN, inf) support in this code. Values
outside the representable range of FPx after rounding are clamped to the
maximum FPx magnitude (sign is preserved).
Code below is an adaptation of https://fburl.com/code/ciwofcg4
Background 1: last answer in https://stackoverflow.com/questions/8981913/how-to-perform-round-to-even-with-floating-point-numbers # noqa: E501
Background 2: Computer Organization and Design, RISC-V edition, Chapter 3.5
"""
assert x.dtype == torch.float
assert 1 + ebits + mbits <= 8

# calculate constants
exp_bias = _n_ones(ebits - 1)
max_int = _n_ones(ebits + mbits)
sign_mask = 1 << (ebits + mbits)

# TODO document this better
magic_adder = _n_ones(MBITS_F32 - mbits - 1)

# all E bits and M bits are 1s
max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits))

# E bits = 1, M bits = 0
min_normal = 2 ** (1 - exp_bias)

denorm_exp = (
# exp bias conversion between formats
(F32_EXP_BIAS - exp_bias)
# mantissa length difference between formats
+ (MBITS_F32 - mbits)
# add one to encoded exponent for denormalized numbers
+ 1
)
denorm_mask_int = denorm_exp << MBITS_F32

# reinterpret int32 as float32
denorm_mask_float = torch.tensor(denorm_mask_int, dtype=torch.int32).view(torch.float32)

# save the sign
# Note that we have torch.uint32, but some ops like cpu bit shifts
# do not work on it. So, we stay in int32.
x = x.view(torch.int32)
sign = x & 0x80000000

# set everything to positive, will add sign back at the end
x = x ^ sign

# TODO: can the branch floating point comparisons below be done without
# converting to float? probably but need to verify
x = x.view(torch.float)

# rewrite saturate/denorm/norm branches without explicit data dependent
# control flow, to be more compiler friendly
saturate_mask = x >= max_normal
denormal_mask = torch.logical_and(torch.logical_not(saturate_mask), x < min_normal)
normal_mask = torch.logical_not(torch.logical_or(saturate_mask, denormal_mask))

#
# branch 1: saturate to max val - handled later in the code which combines
# the branches
#

#
# branch 2: to conversion to denormal as well as rounding up to normal
#
denormal_x = x + denorm_mask_float
denormal_x = denormal_x.view(torch.int32)
denormal_x -= denorm_mask_int
denormal_x = denormal_x.to(torch.uint8)

#
# branch 3: stay in normal range, adjust the exponent and round
#
normal_x = x.view(torch.int32)
# resulting mantissa is odd
mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1
# update exponent, rounding bias part 1
val_to_add = ((exp_bias - F32_EXP_BIAS) << MBITS_F32) + magic_adder
normal_x += val_to_add
# rounding bias part 2
normal_x += mant_odd
# take the bits!
normal_x = normal_x >> (MBITS_F32 - mbits)
normal_x = normal_x.to(torch.uint8)

#
# combine the branches
#
x = torch.full_like(x, max_int, dtype=torch.uint8)
x = torch.where(denormal_mask, denormal_x, x)
x = torch.where(normal_mask, normal_x, x)

# add sign back
sign_lp = sign >> (MBITS_F32 + EBITS_F32 - mbits - ebits)
sign_lp = sign_lp.to(torch.uint8)
# Right shift of a negative signed integer can fill the least significant
# bits with either 1s or 0s, depending on the implementation. Since PyTorch
# doesn't have an uint32 dtype, we mask out these bits to get just the
# f4 sign bit
sign_lp = sign_lp & sign_mask
x = x | sign_lp

return x.to(torch.uint8)


# TODO(future): check if LUT for everything is faster than bit shifting,
# especially for fp4 (only 2^4=16 unique values).
def _fpx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor:
"""Convert sub-byte floating point numbers with the given number of exponent
and mantissa bits to FP32.
Input: torch.Tensor of dtype uint8, where the bit encoding is stored
in the least significant bits. e.g.
fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding
fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding
Output: torch.Tensor of dtype fp32 with the dequantized value
"""
assert x.dtype == torch.uint8
assert 1 + ebits + mbits <= 8

sign_mask = 1 << (ebits + mbits)
exp_bias = _n_ones(ebits - 1)
mantissa_mask = _n_ones(mbits)

# save the sign
sign_lp = x & sign_mask

# set everything to positive, will add sign back at the end
x_pos = x ^ sign_lp

#
# 1. Calculate zero mask
#
zero_mask = x_pos == 0

#
# 2. Calculate the denormal path mask
#
denormal_mask = torch.logical_and((x_pos > 0), ((x_pos >> mbits) == 0))

#
# 3. Calculate the normal path
#

# calculate the new exponent and shift it to bits 2:9 of the result
exp_biased_lp = x_pos >> mbits
exp_biased_f32 = exp_biased_lp - exp_bias + F32_EXP_BIAS
exp_biased_f32 = exp_biased_f32.to(torch.int32) << MBITS_F32

# shift the mantissa to bits 10:32 of the result
mantissa_lp_int32 = (x_pos & mantissa_mask).to(torch.int32)
mantissa_f32 = mantissa_lp_int32 << (MBITS_F32 - mbits)
result = exp_biased_f32 | mantissa_f32

#
# 4. Add the zero and denormal casts to the already casted normal path
#
result[zero_mask] = 0

denormal_exp_biased = 1 - exp_bias + F32_EXP_BIAS

# fast path.
# without this, performance for FP4_E2M1 is slower by 2x
if mbits == 1:
result[denormal_mask] = (denormal_exp_biased - mbits) << MBITS_F32

else:
# iterate over all possible values of mantissa
# i=0, j=1
# i=1, j=10,11
# i=2, j=100,101,110,111
# and so on
for i in range(mbits):
for mantissa_cmp in range(1 << i, 1 << (i+1)):
# left shift mantissa until it overflows (create an implicit 1)
# subtract exponent by the same amount
left_shift = mbits - i
mantissa_f32 = (mantissa_cmp - (1 << i)) << (left_shift + MBITS_F32 - mbits)
exp_biased_f32 = (denormal_exp_biased - left_shift) << MBITS_F32

# we can update this in-place since the values won't overlap
mantissa_lp_int32[mantissa_lp_int32 == mantissa_cmp] = exp_biased_f32 | mantissa_f32

result = torch.where(denormal_mask, mantissa_lp_int32, result)

# add sign back
sign_f32 = sign_lp.to(torch.int32) << (MBITS_F32 - mbits + EBITS_F32 - ebits)
result = result | sign_f32

return result.view(torch.float)
4 changes: 2 additions & 2 deletions torchao/prototype/mx_formats/benchmarks/bench_qdq.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ def run(profile_folder: Optional[str] = None):
data_lp = MXTensor.to_mx(data_hp, elem_dtype, block_size=32)

if not use_fp4_custom_triton_dequant_kernel:
quant = torch.compile(MXTensor.to_mx)
dequant = torch.compile(data_lp.to_dtype)
quant = torch.compile(MXTensor.to_mx, fullgraph=True)
dequant = torch.compile(data_lp.to_dtype, fullgraph=True)
else:
# As of 2024-04, torch.compile didn't work with the
# handwritten triton kernel,
Expand Down
Loading

0 comments on commit 3b848bd

Please sign in to comment.