diff --git a/test/prototype/test_fp6_llm.py b/test/prototype/test_fp6_llm.py
index 6eddc522ab..9ee3faae4a 100644
--- a/test/prototype/test_fp6_llm.py
+++ b/test/prototype/test_fp6_llm.py
@@ -34,7 +34,7 @@ def test_to_tc_float6_e3m2_compile(self, device):
         x = torch.randn(256, 64, device=device)
 
         expected = to_tc_float6_e3m2(x)
-        actual = torch.compile(to_tc_float6_e3m2)(x)
+        actual = torch.compile(to_tc_float6_e3m2, fullgraph=True)(x)
         torch.testing.assert_close(actual, expected)
 
     @parametrize("device", _DEVICES)
@@ -53,7 +53,7 @@ def test_from_tc_float6_e3m2_compile(self, device):
         x = torch.randint(256, size=(M, N * 3 // 4), dtype=torch.uint8, device=device)
 
         expected = from_tc_float6_e3m2(x)
-        actual = torch.compile(from_tc_float6_e3m2)(x)
+        actual = torch.compile(from_tc_float6_e3m2, fullgraph=True)(x)
         torch.testing.assert_close(actual, expected)
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -81,7 +81,7 @@ def test_fp6_llm_linear_compile(self, bias):
         x = torch.randn(N, IC, device=device, dtype=torch.half)
 
         expected = fp6_linear(x)
-        actual = torch.compile(fp6_linear)(x)
+        actual = torch.compile(fp6_linear, fullgraph=True)(x)
         torch.testing.assert_close(actual, expected)
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
diff --git a/torchao/prototype/custom_fp_utils.py b/torchao/prototype/custom_fp_utils.py
new file mode 100644
index 0000000000..1a3e9e34cb
--- /dev/null
+++ b/torchao/prototype/custom_fp_utils.py
@@ -0,0 +1,227 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This script was initially developed for sub-byte MX dtypes (FP4 E2M1, FP6 E3M2, and FP6 E2M3).
+# It has been refactored to support any sub-byte FP dtypes. However, some behaviors of MX dtypes remain:
+#   1. No encodings are reserved for special values (+/-inf, NaN).
+#   2. When downcasting from FP32 to FPx,
+#      - Rounding mode is round to nearest, ties to even.
+#      - Values outside the representable range of FPx after rounding are clamped to the maximum FPx
+#        magnitude (sign is preserved).
+
+import torch
+from torch import Tensor
+
+
+def _n_ones(n: int) -> int:
+    return (1 << n) - 1
+
+
+EBITS_F32, MBITS_F32 = 8, 23
+F32_EXP_BIAS = _n_ones(EBITS_F32 - 1)
+
+
+def _f32_to_fpx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
+    """Convert FP32 numbers to sub-byte floating point numbers with the given
+    number of exponent and mantissa bits.
+
+    Input: torch.Tensor of dtype torch.float
+    Output: torch.Tensor of dtype torch.uint8, where the bit encoding is stored
+    in the least significant bits. e.g.
+      fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding
+      fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding
+
+    Note: there is no support for special values (NaN, inf) in this code. Values
+    outside the representable range of FPx after rounding are clamped to the
+    maximum FPx magnitude (sign is preserved).
+
+    Code below is an adaptation of https://fburl.com/code/ciwofcg4
+
+    Background 1: last answer in https://stackoverflow.com/questions/8981913/how-to-perform-round-to-even-with-floating-point-numbers  # noqa: E501
+    Background 2: Computer Organization and Design, RISC-V edition, Chapter 3.5
+    """
+    assert x.dtype == torch.float
+    assert 1 + ebits + mbits <= 8
+
+    # calculate constants
+    exp_bias = _n_ones(ebits - 1)
+    max_int = _n_ones(ebits + mbits)
+    sign_mask = 1 << (ebits + mbits)
+
+    # TODO document this better
+    magic_adder = _n_ones(MBITS_F32 - mbits - 1)
+
+    # all E bits and M bits are 1s
+    max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits))
+
+    # E bits = 1, M bits = 0
+    min_normal = 2 ** (1 - exp_bias)
+
+    denorm_exp = (
+        # exp bias conversion between formats
+        (F32_EXP_BIAS - exp_bias)
+        # mantissa length difference between formats
+        + (MBITS_F32 - mbits)
+        # add one to encoded exponent for denormalized numbers
+        + 1
+    )
+    denorm_mask_int = denorm_exp << MBITS_F32
+
+    # reinterpret int32 as float32
+    denorm_mask_float = torch.tensor(denorm_mask_int, dtype=torch.int32).view(torch.float32)
+
+    # save the sign
+    # Note that we have torch.uint32, but some ops like cpu bit shifts
+    # do not work on it. So, we stay in int32.
+    x = x.view(torch.int32)
+    sign = x & 0x80000000
+
+    # set everything to positive, will add sign back at the end
+    x = x ^ sign
+
+    # TODO: can the branch floating point comparisons below be done without
+    # converting to float? probably but need to verify
+    x = x.view(torch.float)
+
+    # rewrite saturate/denorm/norm branches without explicit data dependent
+    # control flow, to be more compiler friendly
+    saturate_mask = x >= max_normal
+    denormal_mask = torch.logical_and(torch.logical_not(saturate_mask), x < min_normal)
+    normal_mask = torch.logical_not(torch.logical_or(saturate_mask, denormal_mask))
+
+    #
+    # branch 1: saturate to max val - handled later in the code which combines
+    # the branches
+    #
+
+    #
+    # branch 2: conversion to denormal as well as rounding up to normal
+    #
+    denormal_x = x + denorm_mask_float
+    denormal_x = denormal_x.view(torch.int32)
+    denormal_x -= denorm_mask_int
+    denormal_x = denormal_x.to(torch.uint8)
+
+    #
+    # branch 3: stay in normal range, adjust the exponent and round
+    #
+    normal_x = x.view(torch.int32)
+    # resulting mantissa is odd
+    mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1
+    # update exponent, rounding bias part 1
+    val_to_add = ((exp_bias - F32_EXP_BIAS) << MBITS_F32) + magic_adder
+    normal_x += val_to_add
+    # rounding bias part 2
+    normal_x += mant_odd
+    # take the bits!
+    normal_x = normal_x >> (MBITS_F32 - mbits)
+    normal_x = normal_x.to(torch.uint8)
+
+    #
+    # combine the branches
+    #
+    x = torch.full_like(x, max_int, dtype=torch.uint8)
+    x = torch.where(denormal_mask, denormal_x, x)
+    x = torch.where(normal_mask, normal_x, x)
+
+    # add sign back
+    sign_lp = sign >> (MBITS_F32 + EBITS_F32 - mbits - ebits)
+    sign_lp = sign_lp.to(torch.uint8)
+    # Right shift of a negative signed integer can fill the least significant
+    # bits with either 1s or 0s, depending on the implementation. Since PyTorch
+    # doesn't have a uint32 dtype, we mask out these bits to get just the
+    # fpx sign bit
+    sign_lp = sign_lp & sign_mask
+    x = x | sign_lp
+
+    return x.to(torch.uint8)
+
+
+# TODO(future): check if LUT for everything is faster than bit shifting,
+# especially for fp4 (only 2^4=16 unique values).
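+
+# Worked round-trip example (illustrative sketch only, not exercised by the tests)
+# for FP6 E3M2, i.e. ebits=3, mbits=2, exponent bias 3, through _f32_to_fpx_unpacked
+# above and _fpx_unpacked_to_f32 below:
+#
+#   _f32_to_fpx_unpacked(torch.tensor([0.3]), ebits=3, mbits=2)
+#       -> tensor([5], dtype=torch.uint8)  # 0b000101: sign 0, exponent 001, mantissa 01
+#   _fpx_unpacked_to_f32(torch.tensor([5], dtype=torch.uint8), ebits=3, mbits=2)
+#       -> tensor([0.3125])                # 1.25 * 2**-2, the nearest FP6 E3M2 value to 0.3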
+def _fpx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor:
+    """Convert sub-byte floating point numbers with the given number of exponent
+    and mantissa bits to FP32.
+
+    Input: torch.Tensor of dtype uint8, where the bit encoding is stored
+    in the least significant bits. e.g.
+      fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding
+      fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding
+    Output: torch.Tensor of dtype fp32 with the dequantized value
+    """
+    assert x.dtype == torch.uint8
+    assert 1 + ebits + mbits <= 8
+
+    sign_mask = 1 << (ebits + mbits)
+    exp_bias = _n_ones(ebits - 1)
+    mantissa_mask = _n_ones(mbits)
+
+    # save the sign
+    sign_lp = x & sign_mask
+
+    # set everything to positive, will add sign back at the end
+    x_pos = x ^ sign_lp
+
+    #
+    # 1. Calculate zero mask
+    #
+    zero_mask = x_pos == 0
+
+    #
+    # 2. Calculate the denormal path mask
+    #
+    denormal_mask = torch.logical_and((x_pos > 0), ((x_pos >> mbits) == 0))
+
+    #
+    # 3. Calculate the normal path
+    #
+
+    # calculate the new exponent and shift it to bits 2:9 of the result
+    exp_biased_lp = x_pos >> mbits
+    exp_biased_f32 = exp_biased_lp - exp_bias + F32_EXP_BIAS
+    exp_biased_f32 = exp_biased_f32.to(torch.int32) << MBITS_F32
+
+    # shift the mantissa to bits 10:32 of the result
+    mantissa_lp_int32 = (x_pos & mantissa_mask).to(torch.int32)
+    mantissa_f32 = mantissa_lp_int32 << (MBITS_F32 - mbits)
+    result = exp_biased_f32 | mantissa_f32
+
+    #
+    # 4. Add the zero and denormal casts to the already casted normal path
+    #
+    result[zero_mask] = 0
+
+    denormal_exp_biased = 1 - exp_bias + F32_EXP_BIAS
+
+    # fast path.
+    # without this, performance for FP4_E2M1 is slower by 2x
+    if mbits == 1:
+        result[denormal_mask] = (denormal_exp_biased - mbits) << MBITS_F32
+
+    else:
+        # iterate over all possible values of mantissa
+        # i=0, j=1
+        # i=1, j=10,11
+        # i=2, j=100,101,110,111
+        # and so on
+        for i in range(mbits):
+            for mantissa_cmp in range(1 << i, 1 << (i+1)):
+                # left shift mantissa until it overflows (create an implicit 1)
+                # subtract exponent by the same amount
+                left_shift = mbits - i
+                mantissa_f32 = (mantissa_cmp - (1 << i)) << (left_shift + MBITS_F32 - mbits)
+                exp_biased_f32 = (denormal_exp_biased - left_shift) << MBITS_F32
+
+                # we can update this in-place since the values won't overlap
+                mantissa_lp_int32[mantissa_lp_int32 == mantissa_cmp] = exp_biased_f32 | mantissa_f32
+
+        result = torch.where(denormal_mask, mantissa_lp_int32, result)
+
+    # add sign back
+    sign_f32 = sign_lp.to(torch.int32) << (MBITS_F32 - mbits + EBITS_F32 - ebits)
+    result = result | sign_f32
+
+    return result.view(torch.float)
diff --git a/torchao/prototype/mx_formats/benchmarks/bench_qdq.py b/torchao/prototype/mx_formats/benchmarks/bench_qdq.py
index 6e6e373de3..ffd9d0d050 100644
--- a/torchao/prototype/mx_formats/benchmarks/bench_qdq.py
+++ b/torchao/prototype/mx_formats/benchmarks/bench_qdq.py
@@ -61,8 +61,8 @@ def run(profile_folder: Optional[str] = None):
         data_lp = MXTensor.to_mx(data_hp, elem_dtype, block_size=32)
 
         if not use_fp4_custom_triton_dequant_kernel:
-            quant = torch.compile(MXTensor.to_mx)
-            dequant = torch.compile(data_lp.to_dtype)
+            quant = torch.compile(MXTensor.to_mx, fullgraph=True)
+            dequant = torch.compile(data_lp.to_dtype, fullgraph=True)
        else:
            # As of 2024-04, torch.compile didn't work with the
            # handwritten triton kernel,
diff --git a/torchao/prototype/mx_formats/custom_cast.py b/torchao/prototype/mx_formats/custom_cast.py
index 91aea9275a..f14deba9f6 100644
--- a/torchao/prototype/mx_formats/custom_cast.py
+++ b/torchao/prototype/mx_formats/custom_cast.py
@@ -4,14 +4,13 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-import struct
-
 import numpy as np
 
 import torch
 from torch.utils._triton import has_triton
 
 from torchao.utils import TORCH_VERSION_AFTER_2_4
+from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32
 
 # TODO(future): if needed, make the below work on previous PyTorch versions,
 # just need to hunt down the previous location of `libdevice`. An assert
@@ -20,24 +19,10 @@
 from torch._inductor.runtime.triton_helpers import libdevice
 
 from torchao.prototype.mx_formats.constants import (
-    DTYPE_FP4,
-    DTYPE_FP6_E2M3,
-    DTYPE_FP6_E3M2,
     E8M0_EXPONENT_BIAS,
     E8M0_EXPONENT_NAN_VAL,
     F32_EXP_BIAS,
     F4_E2M1_EXP_BIAS,
-    F4_E2M1_MAX,
-    F4_E2M1_MAX_INT,
-    F4_E2M1_MIN_NORMAL,
-    F6_E2M3_EXP_BIAS,
-    F6_E2M3_MAX,
-    F6_E2M3_MAX_INT,
-    F6_E2M3_MIN_NORMAL,
-    F6_E3M2_EXP_BIAS,
-    F6_E3M2_MAX,
-    F6_E3M2_MAX_INT,
-    F6_E3M2_MIN_NORMAL,
 )
@@ -56,202 +41,20 @@ def get_bits(x: torch.Tensor) -> str:
 EBITS_F6_E2M3, MBITS_F6_E2M3 = 2, 3
 EBITS_F6_E3M2, MBITS_F6_E3M2 = 3, 2
 
-DENORM_F32TOF4_EXP = (
-    # exp bias conversion between formats
-    (F32_EXP_BIAS - F4_E2M1_EXP_BIAS)
-    # mantissa length difference between formats
-    + (MBITS_F32 - MBITS_F4_E2M1)
-    # add one to encoded exponent for denormalized numbers
-    + 1
-)
-DENORM_F32TOF4_MASK_INT = DENORM_F32TOF4_EXP << MBITS_F32
-# reinterpret int32 as float32 in Python
-# see https://stackoverflow.com/a/34446112/1058521
-DENORM_F32TOF4_MASK_FLOAT = struct.unpack(
-    "!f", struct.pack("!I", DENORM_F32TOF4_MASK_INT)
-)[0]
-
-DENORM_F32TOF6_E2M3_EXP = (
-    # exp bias conversion between formats
-    (F32_EXP_BIAS - F6_E2M3_EXP_BIAS)
-    # mantissa length difference between formats
-    + (MBITS_F32 - MBITS_F6_E2M3)
-    # add one to encoded exponent for denormalized numbers
-    + 1
-)
-DENORM_F32TOF6_E2M3_MASK_INT = DENORM_F32TOF6_E2M3_EXP << MBITS_F32
-# reinterpret int32 as float32 in Python
-# see https://stackoverflow.com/a/34446112/1058521
-DENORM_F32TOF6_E2M3_MASK_FLOAT = struct.unpack(
-    "!f", struct.pack("!I", DENORM_F32TOF6_E2M3_MASK_INT)
-)[0]
-
-DENORM_F32TOF6_E3M2_EXP = (
-    # exp bias conversion between formats
-    (F32_EXP_BIAS - F6_E3M2_EXP_BIAS)
-    # mantissa length difference between formats
-    + (MBITS_F32 - MBITS_F6_E3M2)
-    # add one to encoded exponent for denormalized numbers
-    + 1
-)
-DENORM_F32TOF6_E3M2_MASK_INT = DENORM_F32TOF6_E3M2_EXP << MBITS_F32
-# reinterpret int32 as float32 in Python
-# see https://stackoverflow.com/a/34446112/1058521
-DENORM_F32TOF6_E3M2_MASK_FLOAT = struct.unpack(
-    "!f", struct.pack("!I", DENORM_F32TOF6_E3M2_MASK_INT)
-)[0]
-
-#
-# magic value to add during the normal path
-# TODO document this better
-#
-
-# c++ code e5m2:
-# f_bits += ((uint32_t)(15 - 127) << 23) + 0xFFFFF;
-# 0xFFFFF is 1111 1111 1111 1111 1111, 20 ones, 20 = 23 - 3 = 23 - 2 - 1
-
-# c++ code e4m3:
-# f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF;
-# 0x7FFFF is 0111 1111 1111 1111 1111, 19 ones, 19 = 23 - 4 = 23 - 3 - 1
-
-MAGIC_ADDER_F4_E2M1 = 0x1FFFFF  # 21 ones
-MAGIC_ADDER_F6_E2M3 = 0x7FFFF  # 19 ones
-MAGIC_ADDER_F6_E3M2 = 0xFFFFF  # 20 ones
-
-# c++ code named vars
-# f_bits += ((uint32_t)(f8_exp_bias - f32_exp_bias) << f32_mbits) + MAGIC_ADDER;  # noqa: E501
-
 SIGN_MASK_F4 = 0x8  # 1000
-SIGN_MASK_F6_E2M3 = 0x20  # 100000
-SIGN_MASK_F6_E3M2 = 0x20  # 100000
-
 MANTISSA_MASK_F4 = 0x1  # 0001
-MANTISSA_MASK_F6_E2M3 = 0x7  # 000111
-MANTISSA_MASK_F6_E3M2 = 0x3  # 000011
 
 ZERO_BITS_F32 = 0x0
 ZERO_POINT_FIVE_BITS_F32 = 0x3F000000
 
 
-def _f32_to_f4_or_f6_unpacked(
-    x,
-    max_normal,
-    min_normal,
-    denorm_mask_float,
-    denorm_mask_int,
-    ebits,
-    mbits,
-    exp_bias,
-    magic_adder,
-    max_int,
-    sign_mask,
-):
-    """
-    Input: torch.Tensor of dtype torch.float
-    Output: torch.Tensor of dtype torch.uint8,
-      fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding
-      fp6: bits 0-1 empty and bits 2-7 in the fp6_e2m3 or fp6_e3m2 encoding
-
-    Note: there is no special values (NaN, inf) support in this code as the
-    OCP spec does not define special values for fp6 and fp4 dtypes.
-
-    Code below is an adaptation of https://fburl.com/code/ciwofcg4 for f4/f6
-
-    Background 1: last answer in https://stackoverflow.com/questions/8981913/how-to-perform-round-to-even-with-floating-point-numbers  # noqa: E501
-    Background 2: Computer Organization and Design, RISC-V edition, Chapter 3.5
-    """
-    assert x.dtype == torch.float
-
-    # save the sign
-    # Note that we have torch.uint32, but some ops like cpu bit shifts
-    # do not work on it. So, we stay in int32.
-    x = x.view(torch.int32)
-    sign = x & 0x80000000
-
-    # set everything to positive, will add sign back at the end
-    x = x ^ sign
-
-    # TODO: can the branch floating point comparisons below be done without
-    # converting to float? probably but need to verify
-    x = x.view(torch.float)
-
-    # rewrite saturate/denorm/norm branches without explicit data dependent
-    # control flow, to be more compiler friendly
-    saturate_mask = x >= max_normal
-    denormal_mask = torch.logical_and(
-        torch.logical_not(saturate_mask), x < min_normal
-    )  # noqa: E501
-    normal_mask = torch.logical_not(
-        torch.logical_or(saturate_mask, denormal_mask)
-    )  # noqa: E501
-
-    #
-    # branch 1: saturate to max val - handled later in the code which combines
-    # the branches
-    #
-
-    #
-    # branch 2: to conversion to denormal as well as rounding up to normal
-    #
-    denormal_x = x + denorm_mask_float
-    denormal_x = denormal_x.view(torch.int32)
-    denormal_x -= denorm_mask_int
-    denormal_x = denormal_x.to(torch.uint8)
-
-    #
-    # branch 3: stay in normal range, adjust the exponent and round
-    #
-    normal_x = x.view(torch.int32)
-    # resulting mantissa is odd
-    mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1
-    # update exponent, rounding bias part 1
-    val_to_add = ((exp_bias - F32_EXP_BIAS) << MBITS_F32) + magic_adder
-    normal_x += val_to_add
-    # rounding bias part 2
-    normal_x += mant_odd
-    # take the bits!
-    normal_x = normal_x >> (MBITS_F32 - mbits)
-    normal_x = normal_x.to(torch.uint8)
-
-    #
-    # combine the branches
-    #
-    x = torch.full_like(x, max_int, dtype=torch.uint8)
-    x = torch.where(denormal_mask, denormal_x, x)
-    x = torch.where(normal_mask, normal_x, x)
-
-    # add sign back
-    sign_lp = sign >> (MBITS_F32 + EBITS_F32 - mbits - ebits)
-    sign_lp = sign_lp.to(torch.uint8)
-    # Right shift of a negative signed integer can fill the least significant
-    # bits with either 1s or 0s, depending on the implementation. Since PyTorch
-    # doesn't have an uint32 dtype, we mask out these bits to get just the
-    # f4 sign bit
-    sign_lp = sign_lp & sign_mask
-    x = x | sign_lp
-
-    return x.to(torch.uint8)
-
-
 def f32_to_f4_unpacked(x):
     """
     Input: torch.Tensor of dtype torch.float
     Output: torch.Tensor of dtype torch.uint8, with bits 0-3 empty and
       bits 4-7 in fp4_e2m1
     """
-    return _f32_to_f4_or_f6_unpacked(
-        x,
-        F4_E2M1_MAX,
-        F4_E2M1_MIN_NORMAL,
-        DENORM_F32TOF4_MASK_FLOAT,
-        DENORM_F32TOF4_MASK_INT,
-        EBITS_F4_E2M1,
-        MBITS_F4_E2M1,
-        F4_E2M1_EXP_BIAS,
-        MAGIC_ADDER_F4_E2M1,
-        F4_E2M1_MAX_INT,
-        SIGN_MASK_F4,
-    )
+    return _f32_to_fpx_unpacked(x, EBITS_F4_E2M1, MBITS_F4_E2M1)
 
 
 def f32_to_f6_e2m3_unpacked(x):
@@ -260,19 +63,7 @@
     Input: torch.Tensor of dtype torch.float
     Output: torch.Tensor of dtype torch.uint8, with bits 0-1 empty and
       bits 2-7 in fp6_e2m3
     """
-    return _f32_to_f4_or_f6_unpacked(
-        x,
-        F6_E2M3_MAX,
-        F6_E2M3_MIN_NORMAL,
-        DENORM_F32TOF6_E2M3_MASK_FLOAT,
-        DENORM_F32TOF6_E2M3_MASK_INT,
-        EBITS_F6_E2M3,
-        MBITS_F6_E2M3,
-        F6_E2M3_EXP_BIAS,
-        MAGIC_ADDER_F6_E2M3,
-        F6_E2M3_MAX_INT,
-        SIGN_MASK_F6_E2M3,
-    )
+    return _f32_to_fpx_unpacked(x, EBITS_F6_E2M3, MBITS_F6_E2M3)
 
 
 def f32_to_f6_e3m2_unpacked(x):
@@ -281,134 +72,7 @@
     Input: torch.Tensor of dtype torch.float
     Output: torch.Tensor of dtype torch.uint8, with bits 0-1 empty and
       bits 2-7 in fp6_e3m2
     """
-    return _f32_to_f4_or_f6_unpacked(
-        x,
-        F6_E3M2_MAX,
-        F6_E3M2_MIN_NORMAL,
-        DENORM_F32TOF6_E3M2_MASK_FLOAT,
-        DENORM_F32TOF6_E3M2_MASK_INT,
-        EBITS_F6_E3M2,
-        MBITS_F6_E3M2,
-        F6_E3M2_EXP_BIAS,
-        MAGIC_ADDER_F6_E3M2,
-        F6_E3M2_MAX_INT,
-        SIGN_MASK_F6_E3M2,
-    )
-
-
-def _f4_or_f6_unpacked_to_f32(x: torch.Tensor, lp_dtype_name: str):
-    """
-    Input: torch.Tensor of dtype uint8, with bits 0-3 empty and bits 4-7
-      containing an fp4_e2m1 encoding
-    Output: torch.Tensor of dtype fp32 with the dequantized value
-
-    TODO(future): check if LUT for everything is faster than bit shifting,
-    especially for fp4.
-    """
-    assert x.dtype == torch.uint8
-
-    if lp_dtype_name == DTYPE_FP4:
-        sign_mask = SIGN_MASK_F4
-        ebits = EBITS_F4_E2M1
-        mbits = MBITS_F4_E2M1
-        exp_bias = F4_E2M1_EXP_BIAS
-        mantissa_mask = MANTISSA_MASK_F4
-    elif lp_dtype_name == DTYPE_FP6_E2M3:
-        sign_mask = SIGN_MASK_F6_E2M3
-        ebits = EBITS_F6_E2M3
-        mbits = MBITS_F6_E2M3
-        exp_bias = F6_E2M3_EXP_BIAS
-        mantissa_mask = MANTISSA_MASK_F6_E2M3
-    elif lp_dtype_name == DTYPE_FP6_E3M2:
-        sign_mask = SIGN_MASK_F6_E3M2
-        ebits = EBITS_F6_E3M2
-        mbits = MBITS_F6_E3M2
-        exp_bias = F6_E3M2_EXP_BIAS
-        mantissa_mask = MANTISSA_MASK_F6_E3M2
-    else:
-        raise AssertionError(f"unsupported lp_dtype_name {lp_dtype_name}")
-
-    # save the sign
-    sign_lp = x & sign_mask
-
-    # set everything to positive, will add sign back at the end
-    x_pos = x ^ sign_lp
-
-    #
-    # 1. Calculate zero mask
-    #
-    zero_mask = x_pos == 0
-
-    #
-    # 2. Calculate the denormal path mask
-    #
-    denormal_mask = torch.logical_and((x_pos > 0), ((x_pos >> mbits) == 0))
-
-    #
-    # 3. Calculate the normal path
-    #
-
-    # calculate the new exponent and shift it to bits 2:9 of the result
-    exp_biased_lp = x_pos >> mbits
-    exp_biased_f32 = exp_biased_lp - exp_bias + F32_EXP_BIAS
-    exp_biased_f32 = exp_biased_f32.to(torch.int32) << MBITS_F32
-
-    # shift the mantissa to bits 10:32 of the result
-    mantissa_lp_int32 = (x_pos & mantissa_mask).to(torch.int32)
-    mantissa_f32 = mantissa_lp_int32 << (MBITS_F32 - mbits)
-    result = exp_biased_f32 | mantissa_f32
-
-    #
-    # 4. Add the zero and denormal casts to the already casted normal path
-    #
-    result[zero_mask] = ZERO_BITS_F32
-    # Note: for now the denormal path cast is written for readability and
-    # numerical correctness. There is likely a way to optimize the performance,
-    # I just haven't had time to look into it.
-    if lp_dtype_name == DTYPE_FP4:
-        result[denormal_mask] = ZERO_POINT_FIVE_BITS_F32
-
-    elif lp_dtype_name == DTYPE_FP6_E2M3:
-        # Only 7 possible values, just do a LUT
-        # Note: calculate the booleans first because we are modifying
-        # this variable inplace.
-        is_val1 = mantissa_lp_int32 == 1
-        is_val2 = mantissa_lp_int32 == 2
-        is_val3 = mantissa_lp_int32 == 3
-        is_val4 = mantissa_lp_int32 == 4
-        is_val5 = mantissa_lp_int32 == 5
-        is_val6 = mantissa_lp_int32 == 6
-        is_val7 = mantissa_lp_int32 == 7
-        mantissa_lp_int32[is_val1] = 0x3E000000  # 0.125
-        mantissa_lp_int32[is_val2] = 0x3E800000  # 0.25
-        mantissa_lp_int32[is_val3] = 0x3EC00000  # 0.375
-        mantissa_lp_int32[is_val4] = 0x3F000000  # 0.5
-        mantissa_lp_int32[is_val5] = 0x3F200000  # 0.625
-        mantissa_lp_int32[is_val6] = 0x3F400000  # 0.75
-        mantissa_lp_int32[is_val7] = 0x3F600000  # 0.875
-        result = torch.where(denormal_mask, mantissa_lp_int32, result)
-
-    elif lp_dtype_name == DTYPE_FP6_E3M2:
-        # Only 3 possible values, just do a LUT
-        # Note: calculate the booleans first because we are modifying
-        # this variable inplace.
-        is_val1 = mantissa_lp_int32 == 1
-        is_val2 = mantissa_lp_int32 == 2
-        is_val3 = mantissa_lp_int32 == 3
-        mantissa_lp_int32[is_val1] = 0x3D800000  # 0.0625
-        mantissa_lp_int32[is_val2] = 0x3E000000  # 0.125
-        mantissa_lp_int32[is_val3] = 0x3E400000  # 0.1875
-        result = torch.where(denormal_mask, mantissa_lp_int32, result)
-    else:
-        raise AssertionError(f"unsupported lp_dtype_name {lp_dtype_name}")
-
-    # add sign back
-    sign_f32 = sign_lp.to(torch.int32) << (
-        MBITS_F32 - mbits + EBITS_F32 - ebits
-    )  # noqa: E501
-    result = result | sign_f32
-
-    return result.view(torch.float)
+    return _f32_to_fpx_unpacked(x, EBITS_F6_E3M2, MBITS_F6_E3M2)
 
 
 def f4_unpacked_to_f32(x: torch.Tensor):
@@ -417,7 +81,7 @@
     containing an fp4_e2m1 encoding
     Output: torch.Tensor of dtype fp32 with the dequantized value
     """
-    return _f4_or_f6_unpacked_to_f32(x, DTYPE_FP4)
+    return _fpx_unpacked_to_f32(x, EBITS_F4_E2M1, MBITS_F4_E2M1)
 
 
 def f6_e2m3_unpacked_to_f32(x: torch.Tensor):
@@ -426,7 +90,7 @@
     containing an fp6_e3m2 encoding
     Output: torch.Tensor of dtype fp32 with the dequantized value
     """
-    return _f4_or_f6_unpacked_to_f32(x, DTYPE_FP6_E2M3)
+    return _fpx_unpacked_to_f32(x, EBITS_F6_E2M3, MBITS_F6_E2M3)
 
 
 def f6_e3m2_unpacked_to_f32(x: torch.Tensor):
@@ -435,7 +99,7 @@
     containing an fp6_e3m2 encoding
     Output: torch.Tensor of dtype fp32 with the dequantized value
     """
-    return _f4_or_f6_unpacked_to_f32(x, DTYPE_FP6_E3M2)
+    return _fpx_unpacked_to_f32(x, EBITS_F6_E3M2, MBITS_F6_E3M2)
 
 
 if has_triton():
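
A quick way to sanity-check the refactor above (a minimal sketch, using only
functions introduced or kept by this diff) is to round-trip a tensor through the
new shared helpers and compare against the thin wrappers left in custom_cast.py,
which now simply forward to them:

    import torch
    from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32
    from torchao.prototype.mx_formats.custom_cast import (
        f32_to_f6_e3m2_unpacked,
        f6_e3m2_unpacked_to_f32,
    )

    x = torch.randn(64)
    # FP6 E3M2: 3 exponent bits, 2 mantissa bits; encodings live in bits 2-7 of a uint8
    enc = _f32_to_fpx_unpacked(x, ebits=3, mbits=2)
    dec = _fpx_unpacked_to_f32(enc, ebits=3, mbits=2)
    # the wrappers are expected to produce identical results after this refactor
    torch.testing.assert_close(enc, f32_to_f6_e3m2_unpacked(x))
    torch.testing.assert_close(dec, f6_e3m2_unpacked_to_f32(enc))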