From 47f7bc18d3fed4b9f08758ef28ca27e948fb0c40 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 14 Jun 2024 22:03:50 +0800 Subject: [PATCH 1/9] refactor custom fp cast --- torchao/prototype/custom_fp_utils.py | 120 ++++++++++++++++ torchao/prototype/mx_formats/custom_cast.py | 152 +------------------- 2 files changed, 124 insertions(+), 148 deletions(-) create mode 100644 torchao/prototype/custom_fp_utils.py diff --git a/torchao/prototype/custom_fp_utils.py b/torchao/prototype/custom_fp_utils.py new file mode 100644 index 0000000000..d6e1cce6e3 --- /dev/null +++ b/torchao/prototype/custom_fp_utils.py @@ -0,0 +1,120 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import struct + +import torch +from torch import Tensor + + +def _n_ones(n: int) -> int: + return (1 << n) - 1 + + +EBITS_F32, MBITS_F32 = 8, 23 +F32_EXP_BIAS = _n_ones(EBITS_F32 - 1) + + +def _f32_to_fpx_unpacked(x: Tensor, ebits: int, mbits: int): + assert x.dtype == torch.float + + # calculate constants + exp_bias = _n_ones(ebits - 1) + max_int = _n_ones(ebits + mbits) + sign_mask = 1 << (ebits + mbits) + + # TODO document this better + magic_adder = _n_ones(MBITS_F32 - mbits - 1) + + # all E bits and M bits are 1s + max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits)) + + # E bits = 1, M bits = 0 + min_normal = 2 ** (1 - exp_bias) + + denorm_exp = ( + # exp bias conversion between formats + (F32_EXP_BIAS - exp_bias) + # mantissa length difference between formats + + (MBITS_F32 - mbits) + # add one to encoded exponent for denormalized numbers + + 1 + ) + denorm_mask_int = denorm_exp << MBITS_F32 + + # reinterpret int32 as float32 in Python + # see https://stackoverflow.com/a/34446112/1058521 + denorm_mask_float = struct.unpack("!f", struct.pack("!I", denorm_mask_int))[0] + + # save the sign + # Note that we have torch.uint32, but some ops like cpu bit shifts + # do not work on it. So, we stay in int32. + x = x.view(torch.int32) + sign = x & 0x80000000 + + # set everything to positive, will add sign back at the end + x = x ^ sign + + # TODO: can the branch floating point comparisons below be done without + # converting to float? probably but need to verify + x = x.view(torch.float) + + # rewrite saturate/denorm/norm branches without explicit data dependent + # control flow, to be more compiler friendly + saturate_mask = x >= max_normal + denormal_mask = torch.logical_and( + torch.logical_not(saturate_mask), x < min_normal + ) # noqa: E501 + normal_mask = torch.logical_not( + torch.logical_or(saturate_mask, denormal_mask) + ) # noqa: E501 + + # + # branch 1: saturate to max val - handled later in the code which combines + # the branches + # + + # + # branch 2: to conversion to denormal as well as rounding up to normal + # + denormal_x = x + denorm_mask_float + denormal_x = denormal_x.view(torch.int32) + denormal_x -= denorm_mask_int + denormal_x = denormal_x.to(torch.uint8) + + # + # branch 3: stay in normal range, adjust the exponent and round + # + normal_x = x.view(torch.int32) + # resulting mantissa is odd + mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1 + # update exponent, rounding bias part 1 + val_to_add = ((exp_bias - F32_EXP_BIAS) << MBITS_F32) + magic_adder + normal_x += val_to_add + # rounding bias part 2 + normal_x += mant_odd + # take the bits! 
+ normal_x = normal_x >> (MBITS_F32 - mbits) + normal_x = normal_x.to(torch.uint8) + + # + # combine the branches + # + x = torch.full_like(x, max_int, dtype=torch.uint8) + x = torch.where(denormal_mask, denormal_x, x) + x = torch.where(normal_mask, normal_x, x) + + # add sign back + sign_lp = sign >> (MBITS_F32 + EBITS_F32 - mbits - ebits) + sign_lp = sign_lp.to(torch.uint8) + # Right shift of a negative signed integer can fill the least significant + # bits with either 1s or 0s, depending on the implementation. Since PyTorch + # doesn't have an uint32 dtype, we mask out these bits to get just the + # f4 sign bit + sign_lp = sign_lp & sign_mask + x = x | sign_lp + + return x.to(torch.uint8) diff --git a/torchao/prototype/mx_formats/custom_cast.py b/torchao/prototype/mx_formats/custom_cast.py index 91aea9275a..e6d313953e 100644 --- a/torchao/prototype/mx_formats/custom_cast.py +++ b/torchao/prototype/mx_formats/custom_cast.py @@ -12,6 +12,7 @@ from torch.utils._triton import has_triton from torchao.utils import TORCH_VERSION_AFTER_2_4 +from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked # TODO(future): if needed, make the below work on previous PyTorch versions, # just need to hunt down the previous location of `libdevice`. An assert @@ -27,17 +28,8 @@ E8M0_EXPONENT_NAN_VAL, F32_EXP_BIAS, F4_E2M1_EXP_BIAS, - F4_E2M1_MAX, - F4_E2M1_MAX_INT, - F4_E2M1_MIN_NORMAL, F6_E2M3_EXP_BIAS, - F6_E2M3_MAX, - F6_E2M3_MAX_INT, - F6_E2M3_MIN_NORMAL, F6_E3M2_EXP_BIAS, - F6_E3M2_MAX, - F6_E3M2_MAX_INT, - F6_E3M2_MIN_NORMAL, ) @@ -133,125 +125,13 @@ def get_bits(x: torch.Tensor) -> str: ZERO_POINT_FIVE_BITS_F32 = 0x3F000000 -def _f32_to_f4_or_f6_unpacked( - x, - max_normal, - min_normal, - denorm_mask_float, - denorm_mask_int, - ebits, - mbits, - exp_bias, - magic_adder, - max_int, - sign_mask, -): - """ - Input: torch.Tensor of dtype torch.float - Output: torch.Tensor of dtype torch.uint8, - fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding - fp6: bits 0-1 empty and bits 2-7 in the fp6_e2m3 or fp6_e3m2 encoding - - Note: there is no special values (NaN, inf) support in this code as the - OCP spec does not define special values for fp6 and fp4 dtypes. - - Code below is an adaptation of https://fburl.com/code/ciwofcg4 for f4/f6 - - Background 1: last answer in https://stackoverflow.com/questions/8981913/how-to-perform-round-to-even-with-floating-point-numbers # noqa: E501 - Background 2: Computer Organization and Design, RISC-V edition, Chapter 3.5 - """ - assert x.dtype == torch.float - - # save the sign - # Note that we have torch.uint32, but some ops like cpu bit shifts - # do not work on it. So, we stay in int32. - x = x.view(torch.int32) - sign = x & 0x80000000 - - # set everything to positive, will add sign back at the end - x = x ^ sign - - # TODO: can the branch floating point comparisons below be done without - # converting to float? 
probably but need to verify - x = x.view(torch.float) - - # rewrite saturate/denorm/norm branches without explicit data dependent - # control flow, to be more compiler friendly - saturate_mask = x >= max_normal - denormal_mask = torch.logical_and( - torch.logical_not(saturate_mask), x < min_normal - ) # noqa: E501 - normal_mask = torch.logical_not( - torch.logical_or(saturate_mask, denormal_mask) - ) # noqa: E501 - - # - # branch 1: saturate to max val - handled later in the code which combines - # the branches - # - - # - # branch 2: to conversion to denormal as well as rounding up to normal - # - denormal_x = x + denorm_mask_float - denormal_x = denormal_x.view(torch.int32) - denormal_x -= denorm_mask_int - denormal_x = denormal_x.to(torch.uint8) - - # - # branch 3: stay in normal range, adjust the exponent and round - # - normal_x = x.view(torch.int32) - # resulting mantissa is odd - mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1 - # update exponent, rounding bias part 1 - val_to_add = ((exp_bias - F32_EXP_BIAS) << MBITS_F32) + magic_adder - normal_x += val_to_add - # rounding bias part 2 - normal_x += mant_odd - # take the bits! - normal_x = normal_x >> (MBITS_F32 - mbits) - normal_x = normal_x.to(torch.uint8) - - # - # combine the branches - # - x = torch.full_like(x, max_int, dtype=torch.uint8) - x = torch.where(denormal_mask, denormal_x, x) - x = torch.where(normal_mask, normal_x, x) - - # add sign back - sign_lp = sign >> (MBITS_F32 + EBITS_F32 - mbits - ebits) - sign_lp = sign_lp.to(torch.uint8) - # Right shift of a negative signed integer can fill the least significant - # bits with either 1s or 0s, depending on the implementation. Since PyTorch - # doesn't have an uint32 dtype, we mask out these bits to get just the - # f4 sign bit - sign_lp = sign_lp & sign_mask - x = x | sign_lp - - return x.to(torch.uint8) - - def f32_to_f4_unpacked(x): """ Input: torch.Tensor of dtype torch.float Output: torch.Tensor of dtype torch.uint8, with bits 0-3 empty and bits 4-7 in fp4_e2m1 """ - return _f32_to_f4_or_f6_unpacked( - x, - F4_E2M1_MAX, - F4_E2M1_MIN_NORMAL, - DENORM_F32TOF4_MASK_FLOAT, - DENORM_F32TOF4_MASK_INT, - EBITS_F4_E2M1, - MBITS_F4_E2M1, - F4_E2M1_EXP_BIAS, - MAGIC_ADDER_F4_E2M1, - F4_E2M1_MAX_INT, - SIGN_MASK_F4, - ) + return _f32_to_fpx_unpacked(x, EBITS_F4_E2M1, MBITS_F4_E2M1) def f32_to_f6_e2m3_unpacked(x): @@ -260,19 +140,7 @@ def f32_to_f6_e2m3_unpacked(x): Output: torch.Tensor of dtype torch.uint8, with bits 0-1 empty and bits 2-7 in fp6_e2m3 """ - return _f32_to_f4_or_f6_unpacked( - x, - F6_E2M3_MAX, - F6_E2M3_MIN_NORMAL, - DENORM_F32TOF6_E2M3_MASK_FLOAT, - DENORM_F32TOF6_E2M3_MASK_INT, - EBITS_F6_E2M3, - MBITS_F6_E2M3, - F6_E2M3_EXP_BIAS, - MAGIC_ADDER_F6_E2M3, - F6_E2M3_MAX_INT, - SIGN_MASK_F6_E2M3, - ) + return _f32_to_fpx_unpacked(x, EBITS_F6_E2M3, MBITS_F6_E2M3) def f32_to_f6_e3m2_unpacked(x): @@ -281,19 +149,7 @@ def f32_to_f6_e3m2_unpacked(x): Output: torch.Tensor of dtype torch.uint8, with bits 0-1 empty and bits 2-7 in fp6_e3m2 """ - return _f32_to_f4_or_f6_unpacked( - x, - F6_E3M2_MAX, - F6_E3M2_MIN_NORMAL, - DENORM_F32TOF6_E3M2_MASK_FLOAT, - DENORM_F32TOF6_E3M2_MASK_INT, - EBITS_F6_E3M2, - MBITS_F6_E3M2, - F6_E3M2_EXP_BIAS, - MAGIC_ADDER_F6_E3M2, - F6_E3M2_MAX_INT, - SIGN_MASK_F6_E3M2, - ) + return _f32_to_fpx_unpacked(x, EBITS_F6_E3M2, MBITS_F6_E3M2) def _f4_or_f6_unpacked_to_f32(x: torch.Tensor, lp_dtype_name: str): From da1761125ad0ade95799fa9587a0da1aa58ca026 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 14 Jun 2024 23:31:30 +0800 Subject: [PATCH 
2/9] add dequant --- torchao/prototype/custom_fp_utils.py | 74 +++++++++++- torchao/prototype/mx_formats/custom_cast.py | 126 +------------------- 2 files changed, 77 insertions(+), 123 deletions(-) diff --git a/torchao/prototype/custom_fp_utils.py b/torchao/prototype/custom_fp_utils.py index d6e1cce6e3..c360415d7f 100644 --- a/torchao/prototype/custom_fp_utils.py +++ b/torchao/prototype/custom_fp_utils.py @@ -16,10 +16,12 @@ def _n_ones(n: int) -> int: EBITS_F32, MBITS_F32 = 8, 23 F32_EXP_BIAS = _n_ones(EBITS_F32 - 1) +F32_MANTISSA_MASK = _n_ones(MBITS_F32) -def _f32_to_fpx_unpacked(x: Tensor, ebits: int, mbits: int): +def _f32_to_fpx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor: assert x.dtype == torch.float + assert 1 + ebits + mbits <= 8 # calculate constants exp_bias = _n_ones(ebits - 1) @@ -118,3 +120,73 @@ def _f32_to_fpx_unpacked(x: Tensor, ebits: int, mbits: int): x = x | sign_lp return x.to(torch.uint8) + + +def _fpx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor: + """ + TODO(future): check if LUT for everything is faster than bit shifting, + especially for fp4. + """ + assert x.dtype == torch.uint8 + assert 1 + ebits + mbits <= 8 + + sign_mask = 1 << (ebits + mbits) + exp_bias = _n_ones(ebits - 1) + mantissa_mask = _n_ones(mbits) + + # save the sign + sign_lp = x & sign_mask + + # set everything to positive, will add sign back at the end + x_pos = x ^ sign_lp + + # + # 1. Calculate zero mask + # + zero_mask = x_pos == 0 + + # + # 2. Calculate the denormal path mask + # + denormal_mask = torch.logical_and((x_pos > 0), ((x_pos >> mbits) == 0)) + + # + # 3. Calculate the normal path + # + + # calculate the new exponent and shift it to bits 2:9 of the result + exp_biased_lp = x_pos >> mbits + exp_biased_f32 = exp_biased_lp - exp_bias + F32_EXP_BIAS + exp_biased_f32 = exp_biased_f32.to(torch.int32) << MBITS_F32 + + # shift the mantissa to bits 10:32 of the result + mantissa_lp_int32 = (x_pos & mantissa_mask).to(torch.int32) + mantissa_f32 = mantissa_lp_int32 << (MBITS_F32 - mbits) + result = exp_biased_f32 | mantissa_f32 + + # + # 4. Add the zero and denormal casts to the already casted normal path + # + result[zero_mask] = 0 + # Note: for now the denormal path cast is written for readability and + # numerical correctness. There is likely a way to optimize the performance, + # I just haven't had time to look into it. + + denormal_exp_biased = 1 - exp_bias + F32_EXP_BIAS + + # for denormal numbers, we shift left until mantissa bits overflow + # when doing so, we also need to subtract exponent by the same amount. 
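A scratch check of the shift-and-adjust rule described in the comment above, using the fp6_e3m2 parameters (exp_bias = 3, mbits = 2) as an example — this is illustrative commentary, not lines from the diff:

    # denormal encoding 0b000001 stores 2**(1 - exp_bias) * (1 / 2**mbits) = 0.0625;
    # shifting the lone mantissa bit left by 2 turns it into the implicit leading 1,
    # so the exponent must drop by the same 2, giving 2**((1 - exp_bias) - 2) * 1.0.
    exp_bias, mbits = 3, 2
    assert 2 ** (1 - exp_bias) * (1 / 2 ** mbits) == 2 ** ((1 - exp_bias) - mbits) * 1.0
    # the loop that follows does exactly this, directly on the f32 bit pattern.
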
+ for i in range(mbits): + mask = torch.logical_and(denormal_mask, mantissa_lp_int32 >= (1 << i)) + + left_shift = mbits - i + this_mantissa_f32 = (mantissa_f32 << left_shift) & F32_MANTISSA_MASK + this_exp_biased_f32 = (denormal_exp_biased - left_shift) << MBITS_F32 + + result = torch.where(mask, this_exp_biased_f32 | this_mantissa_f32, result) + + # add sign back + sign_f32 = sign_lp.to(torch.int32) << (MBITS_F32 - mbits + EBITS_F32 - ebits) + result = result | sign_f32 + + return result.view(torch.float) diff --git a/torchao/prototype/mx_formats/custom_cast.py b/torchao/prototype/mx_formats/custom_cast.py index e6d313953e..90c094662c 100644 --- a/torchao/prototype/mx_formats/custom_cast.py +++ b/torchao/prototype/mx_formats/custom_cast.py @@ -12,7 +12,7 @@ from torch.utils._triton import has_triton from torchao.utils import TORCH_VERSION_AFTER_2_4 -from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked +from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32 # TODO(future): if needed, make the below work on previous PyTorch versions, # just need to hunt down the previous location of `libdevice`. An assert @@ -21,9 +21,6 @@ from torch._inductor.runtime.triton_helpers import libdevice from torchao.prototype.mx_formats.constants import ( - DTYPE_FP4, - DTYPE_FP6_E2M3, - DTYPE_FP6_E3M2, E8M0_EXPONENT_BIAS, E8M0_EXPONENT_NAN_VAL, F32_EXP_BIAS, @@ -152,128 +149,13 @@ def f32_to_f6_e3m2_unpacked(x): return _f32_to_fpx_unpacked(x, EBITS_F6_E3M2, MBITS_F6_E3M2) -def _f4_or_f6_unpacked_to_f32(x: torch.Tensor, lp_dtype_name: str): - """ - Input: torch.Tensor of dtype uint8, with bits 0-3 empty and bits 4-7 - containing an fp4_e2m1 encoding - Output: torch.Tensor of dtype fp32 with the dequantized value - - TODO(future): check if LUT for everything is faster than bit shifting, - especially for fp4. - """ - assert x.dtype == torch.uint8 - - if lp_dtype_name == DTYPE_FP4: - sign_mask = SIGN_MASK_F4 - ebits = EBITS_F4_E2M1 - mbits = MBITS_F4_E2M1 - exp_bias = F4_E2M1_EXP_BIAS - mantissa_mask = MANTISSA_MASK_F4 - elif lp_dtype_name == DTYPE_FP6_E2M3: - sign_mask = SIGN_MASK_F6_E2M3 - ebits = EBITS_F6_E2M3 - mbits = MBITS_F6_E2M3 - exp_bias = F6_E2M3_EXP_BIAS - mantissa_mask = MANTISSA_MASK_F6_E2M3 - elif lp_dtype_name == DTYPE_FP6_E3M2: - sign_mask = SIGN_MASK_F6_E3M2 - ebits = EBITS_F6_E3M2 - mbits = MBITS_F6_E3M2 - exp_bias = F6_E3M2_EXP_BIAS - mantissa_mask = MANTISSA_MASK_F6_E3M2 - else: - raise AssertionError(f"unsupported lp_dtype_name {lp_dtype_name}") - - # save the sign - sign_lp = x & sign_mask - - # set everything to positive, will add sign back at the end - x_pos = x ^ sign_lp - - # - # 1. Calculate zero mask - # - zero_mask = x_pos == 0 - - # - # 2. Calculate the denormal path mask - # - denormal_mask = torch.logical_and((x_pos > 0), ((x_pos >> mbits) == 0)) - - # - # 3. Calculate the normal path - # - - # calculate the new exponent and shift it to bits 2:9 of the result - exp_biased_lp = x_pos >> mbits - exp_biased_f32 = exp_biased_lp - exp_bias + F32_EXP_BIAS - exp_biased_f32 = exp_biased_f32.to(torch.int32) << MBITS_F32 - - # shift the mantissa to bits 10:32 of the result - mantissa_lp_int32 = (x_pos & mantissa_mask).to(torch.int32) - mantissa_f32 = mantissa_lp_int32 << (MBITS_F32 - mbits) - result = exp_biased_f32 | mantissa_f32 - - # - # 4. 
Add the zero and denormal casts to the already casted normal path - # - result[zero_mask] = ZERO_BITS_F32 - # Note: for now the denormal path cast is written for readability and - # numerical correctness. There is likely a way to optimize the performance, - # I just haven't had time to look into it. - if lp_dtype_name == DTYPE_FP4: - result[denormal_mask] = ZERO_POINT_FIVE_BITS_F32 - - elif lp_dtype_name == DTYPE_FP6_E2M3: - # Only 7 possible values, just do a LUT - # Note: calculate the booleans first because we are modifying - # this variable inplace. - is_val1 = mantissa_lp_int32 == 1 - is_val2 = mantissa_lp_int32 == 2 - is_val3 = mantissa_lp_int32 == 3 - is_val4 = mantissa_lp_int32 == 4 - is_val5 = mantissa_lp_int32 == 5 - is_val6 = mantissa_lp_int32 == 6 - is_val7 = mantissa_lp_int32 == 7 - mantissa_lp_int32[is_val1] = 0x3E000000 # 0.125 - mantissa_lp_int32[is_val2] = 0x3E800000 # 0.25 - mantissa_lp_int32[is_val3] = 0x3EC00000 # 0.375 - mantissa_lp_int32[is_val4] = 0x3F000000 # 0.5 - mantissa_lp_int32[is_val5] = 0x3F200000 # 0.625 - mantissa_lp_int32[is_val6] = 0x3F400000 # 0.75 - mantissa_lp_int32[is_val7] = 0x3F600000 # 0.875 - result = torch.where(denormal_mask, mantissa_lp_int32, result) - - elif lp_dtype_name == DTYPE_FP6_E3M2: - # Only 3 possible values, just do a LUT - # Note: calculate the booleans first because we are modifying - # this variable inplace. - is_val1 = mantissa_lp_int32 == 1 - is_val2 = mantissa_lp_int32 == 2 - is_val3 = mantissa_lp_int32 == 3 - mantissa_lp_int32[is_val1] = 0x3D800000 # 0.0625 - mantissa_lp_int32[is_val2] = 0x3E000000 # 0.125 - mantissa_lp_int32[is_val3] = 0x3E400000 # 0.1875 - result = torch.where(denormal_mask, mantissa_lp_int32, result) - else: - raise AssertionError(f"unsupported lp_dtype_name {lp_dtype_name}") - - # add sign back - sign_f32 = sign_lp.to(torch.int32) << ( - MBITS_F32 - mbits + EBITS_F32 - ebits - ) # noqa: E501 - result = result | sign_f32 - - return result.view(torch.float) - - def f4_unpacked_to_f32(x: torch.Tensor): """ Input: torch.Tensor of dtype uint8, with bits 0-3 empty and bits 4-7 containing an fp4_e2m1 encoding Output: torch.Tensor of dtype fp32 with the dequantized value """ - return _f4_or_f6_unpacked_to_f32(x, DTYPE_FP4) + return _fpx_unpacked_to_f32(x, EBITS_F4_E2M1, MBITS_F4_E2M1) def f6_e2m3_unpacked_to_f32(x: torch.Tensor): @@ -282,7 +164,7 @@ def f6_e2m3_unpacked_to_f32(x: torch.Tensor): containing an fp6_e3m2 encoding Output: torch.Tensor of dtype fp32 with the dequantized value """ - return _f4_or_f6_unpacked_to_f32(x, DTYPE_FP6_E2M3) + return _fpx_unpacked_to_f32(x, EBITS_F6_E2M3, MBITS_F6_E2M3) def f6_e3m2_unpacked_to_f32(x: torch.Tensor): @@ -291,7 +173,7 @@ def f6_e3m2_unpacked_to_f32(x: torch.Tensor): containing an fp6_e3m2 encoding Output: torch.Tensor of dtype fp32 with the dequantized value """ - return _f4_or_f6_unpacked_to_f32(x, DTYPE_FP6_E3M2) + return _fpx_unpacked_to_f32(x, EBITS_F6_E3M2, MBITS_F6_E3M2) if has_triton(): From 334574027bb8f6e601b99f831d189a7ff66bc882 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Fri, 14 Jun 2024 23:38:59 +0800 Subject: [PATCH 3/9] small formating --- torchao/prototype/custom_fp_utils.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/torchao/prototype/custom_fp_utils.py b/torchao/prototype/custom_fp_utils.py index c360415d7f..beb3b395a9 100644 --- a/torchao/prototype/custom_fp_utils.py +++ b/torchao/prototype/custom_fp_utils.py @@ -67,12 +67,8 @@ def _f32_to_fpx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor: # rewrite 
saturate/denorm/norm branches without explicit data dependent # control flow, to be more compiler friendly saturate_mask = x >= max_normal - denormal_mask = torch.logical_and( - torch.logical_not(saturate_mask), x < min_normal - ) # noqa: E501 - normal_mask = torch.logical_not( - torch.logical_or(saturate_mask, denormal_mask) - ) # noqa: E501 + denormal_mask = torch.logical_and(torch.logical_not(saturate_mask), x < min_normal) + normal_mask = torch.logical_not(torch.logical_or(saturate_mask, denormal_mask)) # # branch 1: saturate to max val - handled later in the code which combines From 2690b927299ac997d4f0096f88dc174e2eac64ac Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 15 Jun 2024 08:40:54 +0800 Subject: [PATCH 4/9] compile with fullgraph=True --- test/prototype/test_fp6_llm.py | 6 +++--- torchao/prototype/custom_fp_utils.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/test/prototype/test_fp6_llm.py b/test/prototype/test_fp6_llm.py index 6eddc522ab..9ee3faae4a 100644 --- a/test/prototype/test_fp6_llm.py +++ b/test/prototype/test_fp6_llm.py @@ -34,7 +34,7 @@ def test_to_tc_float6_e3m2_compile(self, device): x = torch.randn(256, 64, device=device) expected = to_tc_float6_e3m2(x) - actual = torch.compile(to_tc_float6_e3m2)(x) + actual = torch.compile(to_tc_float6_e3m2, fullgraph=True)(x) torch.testing.assert_close(actual, expected) @parametrize("device", _DEVICES) @@ -53,7 +53,7 @@ def test_from_tc_float6_e3m2_compile(self, device): x = torch.randint(256, size=(M, N * 3 // 4), dtype=torch.uint8, device=device) expected = from_tc_float6_e3m2(x) - actual = torch.compile(from_tc_float6_e3m2)(x) + actual = torch.compile(from_tc_float6_e3m2, fullgraph=True)(x) torch.testing.assert_close(actual, expected) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -81,7 +81,7 @@ def test_fp6_llm_linear_compile(self, bias): x = torch.randn(N, IC, device=device, dtype=torch.half) expected = fp6_linear(x) - actual = torch.compile(fp6_linear)(x) + actual = torch.compile(fp6_linear, fullgraph=True)(x) torch.testing.assert_close(actual, expected) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") diff --git a/torchao/prototype/custom_fp_utils.py b/torchao/prototype/custom_fp_utils.py index beb3b395a9..01b91bd1c3 100644 --- a/torchao/prototype/custom_fp_utils.py +++ b/torchao/prototype/custom_fp_utils.py @@ -49,7 +49,8 @@ def _f32_to_fpx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor: # reinterpret int32 as float32 in Python # see https://stackoverflow.com/a/34446112/1058521 - denorm_mask_float = struct.unpack("!f", struct.pack("!I", denorm_mask_int))[0] + # denorm_mask_float = struct.unpack("!f", struct.pack("!I", denorm_mask_int))[0] + denorm_mask_float = torch.tensor(denorm_mask_int, dtype=torch.int32).view(torch.float32) # save the sign # Note that we have torch.uint32, but some ops like cpu bit shifts From 8aa0146f3a83e5b1cdf33352454846419a488c5a Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 15 Jun 2024 08:53:13 +0800 Subject: [PATCH 5/9] add fullgraph=true --- torchao/prototype/mx_formats/benchmarks/bench_qdq.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchao/prototype/mx_formats/benchmarks/bench_qdq.py b/torchao/prototype/mx_formats/benchmarks/bench_qdq.py index 6e6e373de3..9edc3087a2 100644 --- a/torchao/prototype/mx_formats/benchmarks/bench_qdq.py +++ b/torchao/prototype/mx_formats/benchmarks/bench_qdq.py @@ -61,8 +61,8 @@ def run(profile_folder: Optional[str] = 
None): data_lp = MXTensor.to_mx(data_hp, elem_dtype, block_size=32) if not use_fp4_custom_triton_dequant_kernel: - quant = torch.compile(MXTensor.to_mx) - dequant = torch.compile(data_lp.to_dtype) + quant = torch.compile(MXTensor.to_mx, fullgraph=True) + dequant = torch.compile(data_lp.to_dtype, fullgraph=True) else: # As of 2024-04, torch.compile didn't work with the # handwritten triton kernel, @@ -141,7 +141,7 @@ def run(profile_folder: Optional[str] = None): torch._dynamo.reset() - print(tabulate.tabulate(results, headers=headers, floatfmt=".2f")) + print(tabulate.tabulate(results, headers=headers, floatfmt=".2f", tablefmt="github")) if __name__ == "__main__": From be77632aa9662eaec0cd63a49773ce96989ce174 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 15 Jun 2024 08:54:51 +0800 Subject: [PATCH 6/9] undo --- torchao/prototype/mx_formats/benchmarks/bench_qdq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/prototype/mx_formats/benchmarks/bench_qdq.py b/torchao/prototype/mx_formats/benchmarks/bench_qdq.py index 9edc3087a2..ffd9d0d050 100644 --- a/torchao/prototype/mx_formats/benchmarks/bench_qdq.py +++ b/torchao/prototype/mx_formats/benchmarks/bench_qdq.py @@ -141,7 +141,7 @@ def run(profile_folder: Optional[str] = None): torch._dynamo.reset() - print(tabulate.tabulate(results, headers=headers, floatfmt=".2f", tablefmt="github")) + print(tabulate.tabulate(results, headers=headers, floatfmt=".2f")) if __name__ == "__main__": From 95f45823d22bc7257eeac74677855f2049040d53 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 15 Jun 2024 11:08:20 +0800 Subject: [PATCH 7/9] add another version --- torchao/prototype/custom_fp_utils.py | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/torchao/prototype/custom_fp_utils.py b/torchao/prototype/custom_fp_utils.py index 01b91bd1c3..a46e441fa0 100644 --- a/torchao/prototype/custom_fp_utils.py +++ b/torchao/prototype/custom_fp_utils.py @@ -171,16 +171,29 @@ def _fpx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor: denormal_exp_biased = 1 - exp_bias + F32_EXP_BIAS - # for denormal numbers, we shift left until mantissa bits overflow - # when doing so, we also need to subtract exponent by the same amount. + # 1st iteration: 1 + # 2nd iteration: 10, 11 + # 3rd iteration: 100, 101, 110, 111 + # and so on for i in range(mbits): - mask = torch.logical_and(denormal_mask, mantissa_lp_int32 >= (1 << i)) + for j in range(1 << i, 1 << (i+1)): + left_shift = mbits - i + mantissa = (j - (1 << i)) << (left_shift + MBITS_F32 - mbits) + exp = (denormal_exp_biased - left_shift) << MBITS_F32 + mantissa_lp_int32[mantissa_lp_int32 == j] = exp | mantissa - left_shift = mbits - i - this_mantissa_f32 = (mantissa_f32 << left_shift) & F32_MANTISSA_MASK - this_exp_biased_f32 = (denormal_exp_biased - left_shift) << MBITS_F32 + result = torch.where(denormal_mask, mantissa_lp_int32, result) - result = torch.where(mask, this_exp_biased_f32 | this_mantissa_f32, result) + # # for denormal numbers, we shift left until mantissa bits overflow + # # when doing so, we also need to subtract exponent by the same amount. 
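To see that the nested enumeration added above reproduces the hard-coded fp6_e3m2 LUT values removed in patch 2 (0x3D800000, 0x3E000000, 0x3E400000), here is a standalone scratch version with the needed constants re-declared locally so it runs on its own; it is not part of the patch:

    MBITS_F32, F32_EXP_BIAS = 23, 127   # same constants as custom_fp_utils.py
    mbits, exp_bias = 2, 3              # fp6_e3m2
    denormal_exp_biased = 1 - exp_bias + F32_EXP_BIAS

    lut = {}
    for i in range(mbits):
        for j in range(1 << i, 1 << (i + 1)):
            left_shift = mbits - i
            mantissa = (j - (1 << i)) << (left_shift + MBITS_F32 - mbits)
            exp = (denormal_exp_biased - left_shift) << MBITS_F32
            lut[j] = exp | mantissa

    # matches the fp6_e3m2 denormal LUT removed earlier: 0.0625, 0.125, 0.1875
    assert lut == {1: 0x3D800000, 2: 0x3E000000, 3: 0x3E400000}
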
+ # for i in range(mbits): + # mask = torch.logical_and(denormal_mask, mantissa_lp_int32 >= (1 << i)) + + # left_shift = mbits - i + # this_mantissa_f32 = (mantissa_f32 << left_shift) & F32_MANTISSA_MASK + # this_exp_biased_f32 = (denormal_exp_biased - left_shift) << MBITS_F32 + + # result = torch.where(mask, this_exp_biased_f32 | this_mantissa_f32, result) # add sign back sign_f32 = sign_lp.to(torch.int32) << (MBITS_F32 - mbits + EBITS_F32 - ebits) From dcd5a05a77fe6d4cb6464bcd58930a2e33978d91 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Sat, 15 Jun 2024 15:27:27 +0800 Subject: [PATCH 8/9] fast path for mbits=1 --- torchao/prototype/custom_fp_utils.py | 62 ++++++++--------- torchao/prototype/mx_formats/custom_cast.py | 74 --------------------- 2 files changed, 31 insertions(+), 105 deletions(-) diff --git a/torchao/prototype/custom_fp_utils.py b/torchao/prototype/custom_fp_utils.py index a46e441fa0..646ead6097 100644 --- a/torchao/prototype/custom_fp_utils.py +++ b/torchao/prototype/custom_fp_utils.py @@ -4,7 +4,13 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -import struct +# This script was initially developed for sub-byte MX dtypes (FP4 E2M1, FP6 E3M2, and FP6 E2M3). +# It has been refactored to support any sub-byte FP dtypes. However, some behaviors of MX dtypes remain: +# 1. No encodings are reserved for special values (+/-inf, NaN). +# 2. When downcasting from FP32 to FPx, +# - Rounding mode is round to nearest, ties to even. +# - Values outside the representable range of FPx after rounding are clamped to the maximum FPx +# magnitude (sign is preserved). import torch from torch import Tensor @@ -16,7 +22,6 @@ def _n_ones(n: int) -> int: EBITS_F32, MBITS_F32 = 8, 23 F32_EXP_BIAS = _n_ones(EBITS_F32 - 1) -F32_MANTISSA_MASK = _n_ones(MBITS_F32) def _f32_to_fpx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor: @@ -47,9 +52,7 @@ def _f32_to_fpx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor: ) denorm_mask_int = denorm_exp << MBITS_F32 - # reinterpret int32 as float32 in Python - # see https://stackoverflow.com/a/34446112/1058521 - # denorm_mask_float = struct.unpack("!f", struct.pack("!I", denorm_mask_int))[0] + # reinterpret int32 as float32 denorm_mask_float = torch.tensor(denorm_mask_int, dtype=torch.int32).view(torch.float32) # save the sign @@ -165,35 +168,32 @@ def _fpx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor: # 4. Add the zero and denormal casts to the already casted normal path # result[zero_mask] = 0 - # Note: for now the denormal path cast is written for readability and - # numerical correctness. There is likely a way to optimize the performance, - # I just haven't had time to look into it. denormal_exp_biased = 1 - exp_bias + F32_EXP_BIAS - # 1st iteration: 1 - # 2nd iteration: 10, 11 - # 3rd iteration: 100, 101, 110, 111 - # and so on - for i in range(mbits): - for j in range(1 << i, 1 << (i+1)): - left_shift = mbits - i - mantissa = (j - (1 << i)) << (left_shift + MBITS_F32 - mbits) - exp = (denormal_exp_biased - left_shift) << MBITS_F32 - mantissa_lp_int32[mantissa_lp_int32 == j] = exp | mantissa - - result = torch.where(denormal_mask, mantissa_lp_int32, result) - - # # for denormal numbers, we shift left until mantissa bits overflow - # # when doing so, we also need to subtract exponent by the same amount. 
- # for i in range(mbits): - # mask = torch.logical_and(denormal_mask, mantissa_lp_int32 >= (1 << i)) - - # left_shift = mbits - i - # this_mantissa_f32 = (mantissa_f32 << left_shift) & F32_MANTISSA_MASK - # this_exp_biased_f32 = (denormal_exp_biased - left_shift) << MBITS_F32 - - # result = torch.where(mask, this_exp_biased_f32 | this_mantissa_f32, result) + # fast path. + # without this, performance for FP4_E2M1 is slower by 2x + if mbits == 1: + result[denormal_mask] = (denormal_exp_biased - mbits) << MBITS_F32 + + else: + # iterate over all possible values of mantissa + # i=0, j=1 + # i=1, j=10,11 + # i=2, j=100,101,110,111 + # and so on + for i in range(mbits): + for mantissa_cmp in range(1 << i, 1 << (i+1)): + # left shift mantissa until it overflows (create an implicit 1) + # subtract exponent by the same amount + left_shift = mbits - i + mantissa_f32 = (mantissa_cmp - (1 << i)) << (left_shift + MBITS_F32 - mbits) + exp_biased_f32 = (denormal_exp_biased - left_shift) << MBITS_F32 + + # we can update this in-place since the values won't overlap + mantissa_lp_int32[mantissa_lp_int32 == mantissa_cmp] = exp_biased_f32 | mantissa_f32 + + result = torch.where(denormal_mask, mantissa_lp_int32, result) # add sign back sign_f32 = sign_lp.to(torch.int32) << (MBITS_F32 - mbits + EBITS_F32 - ebits) diff --git a/torchao/prototype/mx_formats/custom_cast.py b/torchao/prototype/mx_formats/custom_cast.py index 90c094662c..f14deba9f6 100644 --- a/torchao/prototype/mx_formats/custom_cast.py +++ b/torchao/prototype/mx_formats/custom_cast.py @@ -4,8 +4,6 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -import struct - import numpy as np import torch @@ -25,8 +23,6 @@ E8M0_EXPONENT_NAN_VAL, F32_EXP_BIAS, F4_E2M1_EXP_BIAS, - F6_E2M3_EXP_BIAS, - F6_E3M2_EXP_BIAS, ) @@ -45,78 +41,8 @@ def get_bits(x: torch.Tensor) -> str: EBITS_F6_E2M3, MBITS_F6_E2M3 = 2, 3 EBITS_F6_E3M2, MBITS_F6_E3M2 = 3, 2 -DENORM_F32TOF4_EXP = ( - # exp bias conversion between formats - (F32_EXP_BIAS - F4_E2M1_EXP_BIAS) - # mantissa length difference between formats - + (MBITS_F32 - MBITS_F4_E2M1) - # add one to encoded exponent for denormalized numbers - + 1 -) -DENORM_F32TOF4_MASK_INT = DENORM_F32TOF4_EXP << MBITS_F32 -# reinterpret int32 as float32 in Python -# see https://stackoverflow.com/a/34446112/1058521 -DENORM_F32TOF4_MASK_FLOAT = struct.unpack( - "!f", struct.pack("!I", DENORM_F32TOF4_MASK_INT) -)[0] - -DENORM_F32TOF6_E2M3_EXP = ( - # exp bias conversion between formats - (F32_EXP_BIAS - F6_E2M3_EXP_BIAS) - # mantissa length difference between formats - + (MBITS_F32 - MBITS_F6_E2M3) - # add one to encoded exponent for denormalized numbers - + 1 -) -DENORM_F32TOF6_E2M3_MASK_INT = DENORM_F32TOF6_E2M3_EXP << MBITS_F32 -# reinterpret int32 as float32 in Python -# see https://stackoverflow.com/a/34446112/1058521 -DENORM_F32TOF6_E2M3_MASK_FLOAT = struct.unpack( - "!f", struct.pack("!I", DENORM_F32TOF6_E2M3_MASK_INT) -)[0] - -DENORM_F32TOF6_E3M2_EXP = ( - # exp bias conversion between formats - (F32_EXP_BIAS - F6_E3M2_EXP_BIAS) - # mantissa length difference between formats - + (MBITS_F32 - MBITS_F6_E3M2) - # add one to encoded exponent for denormalized numbers - + 1 -) -DENORM_F32TOF6_E3M2_MASK_INT = DENORM_F32TOF6_E3M2_EXP << MBITS_F32 -# reinterpret int32 as float32 in Python -# see https://stackoverflow.com/a/34446112/1058521 -DENORM_F32TOF6_E3M2_MASK_FLOAT = struct.unpack( - "!f", struct.pack("!I", DENORM_F32TOF6_E3M2_MASK_INT) -)[0] - -# -# magic value 
to add during the normal path -# TODO document this better -# - -# c++ code e5m2: -# f_bits += ((uint32_t)(15 - 127) << 23) + 0xFFFFF; -# 0xFFFFF is 1111 1111 1111 1111 1111, 20 ones, 20 = 23 - 3 = 23 - 2 - 1 - -# c++ code e4m3: -# f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF; -# 0x7FFFF is 0111 1111 1111 1111 1111, 19 ones, 19 = 23 - 4 = 23 - 3 - 1 - -MAGIC_ADDER_F4_E2M1 = 0x1FFFFF # 21 ones -MAGIC_ADDER_F6_E2M3 = 0x7FFFF # 19 ones -MAGIC_ADDER_F6_E3M2 = 0xFFFFF # 20 ones - -# c++ code named vars -# f_bits += ((uint32_t)(f8_exp_bias - f32_exp_bias) << f32_mbits) + MAGIC_ADDER; # noqa: E501 - SIGN_MASK_F4 = 0x8 # 1000 -SIGN_MASK_F6_E2M3 = 0x20 # 100000 -SIGN_MASK_F6_E3M2 = 0x20 # 100000 - MANTISSA_MASK_F4 = 0x1 # 0001 -MANTISSA_MASK_F6_E2M3 = 0x7 # 000111 -MANTISSA_MASK_F6_E3M2 = 0x3 # 000011 ZERO_BITS_F32 = 0x0 ZERO_POINT_FIVE_BITS_F32 = 0x3F000000 From bd64efcbd8b944ccb04a5bb6b2363b49c040de71 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 17 Jun 2024 21:57:01 +0800 Subject: [PATCH 9/9] add back docstring --- torchao/prototype/custom_fp_utils.py | 31 +++++++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/torchao/prototype/custom_fp_utils.py b/torchao/prototype/custom_fp_utils.py index 646ead6097..1a3e9e34cb 100644 --- a/torchao/prototype/custom_fp_utils.py +++ b/torchao/prototype/custom_fp_utils.py @@ -25,6 +25,24 @@ def _n_ones(n: int) -> int: def _f32_to_fpx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor: + """Convert FP32 numbers to sub-byte floating point numbers with the given + number of exponent and mantissa bits. + + Input: torch.Tensor of dtype torch.float + Output: torch.Tensor of dtype torch.uint8, where the bit encoding is stored + in the least significant bits. e.g. + fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding + fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding + + Note: there are no special values (NaN, inf) support in this code. Values + outside the representable range of FPx after rounding are clamped to the + maximum FPx magnitude (sign is preserved). + + Code below is an adaptation of https://fburl.com/code/ciwofcg4 + + Background 1: last answer in https://stackoverflow.com/questions/8981913/how-to-perform-round-to-even-with-floating-point-numbers # noqa: E501 + Background 2: Computer Organization and Design, RISC-V edition, Chapter 3.5 + """ assert x.dtype == torch.float assert 1 + ebits + mbits <= 8 @@ -122,10 +140,17 @@ def _f32_to_fpx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor: return x.to(torch.uint8) +# TODO(future): check if LUT for everything is faster than bit shifting, +# especially for fp4 (only 2^4=16 unique values). def _fpx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor: - """ - TODO(future): check if LUT for everything is faster than bit shifting, - especially for fp4. + """Convert sub-byte floating point numbers with the given number of exponent + and mantissa bits to FP32. + + Input: torch.Tensor of dtype uint8, where the bit encoding is stored + in the least significant bits. e.g. + fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding + fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding + Output: torch.Tensor of dtype fp32 with the dequantized value """ assert x.dtype == torch.uint8 assert 1 + ebits + mbits <= 8
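Taken together, the series replaces the per-dtype cast helpers with two generic functions keyed only on (ebits, mbits). A hedged usage sketch of the resulting API — the import path assumes the torchao/prototype/custom_fp_utils.py module added in patch 1:

    import torch
    from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32

    # fp6_e3m2 (ebits=3, mbits=2): values already representable should round-trip
    # exactly; 0.0625 and 0.1875 exercise the denormal path, 28.0 is the max normal.
    x = torch.tensor([0.0, 0.0625, 0.1875, -1.5, 28.0])
    x_lp = _f32_to_fpx_unpacked(x, ebits=3, mbits=2)   # uint8, bits 2-7 used
    torch.testing.assert_close(_fpx_unpacked_to_f32(x_lp, ebits=3, mbits=2), x)

    # fp4_e2m1 (ebits=2, mbits=1): the only positive denormal encoding, 0b0001,
    # dequantizes to 0.5 — the value the mbits == 1 fast path from patch 8 hard-codes.
    f4 = torch.tensor([0b0001], dtype=torch.uint8)
    assert _fpx_unpacked_to_f32(f4, ebits=2, mbits=1).item() == 0.5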