Skip to content

Commit

Permalink
fast path for mbits=1
Browse files Browse the repository at this point in the history
  • Loading branch information
gau-nernst committed Jun 15, 2024
1 parent 95f4582 commit dcd5a05
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 105 deletions.
62 changes: 31 additions & 31 deletions torchao/prototype/custom_fp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import struct
# This script was initially developed for sub-byte MX dtypes (FP4 E2M1, FP6 E3M2, and FP6 E2M3).
# It has since been refactored to support any sub-byte FP dtype. However, the following
# behaviors of MX dtypes remain:
# 1. No encodings are reserved for special values (+/-inf, NaN).
# 2. When downcasting from FP32 to FPx:
#    - The rounding mode is round to nearest, ties to even.
#    - Values that fall outside the representable range of FPx after rounding are clamped
#      to the maximum FPx magnitude (the sign is preserved).

import torch
from torch import Tensor
Expand All @@ -16,7 +22,6 @@ def _n_ones(n: int) -> int:

EBITS_F32, MBITS_F32 = 8, 23
F32_EXP_BIAS = _n_ones(EBITS_F32 - 1)
F32_MANTISSA_MASK = _n_ones(MBITS_F32)


def _f32_to_fpx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
Expand Down Expand Up @@ -47,9 +52,7 @@ def _f32_to_fpx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
)
denorm_mask_int = denorm_exp << MBITS_F32

# reinterpret int32 as float32 in Python
# see https://stackoverflow.com/a/34446112/1058521
# denorm_mask_float = struct.unpack("!f", struct.pack("!I", denorm_mask_int))[0]
# reinterpret int32 as float32
denorm_mask_float = torch.tensor(denorm_mask_int, dtype=torch.int32).view(torch.float32)

# save the sign
Expand Down Expand Up @@ -165,35 +168,32 @@ def _fpx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor:
# 4. Add the zero and denormal casts to the already casted normal path
#
result[zero_mask] = 0
# Note: for now the denormal path cast is written for readability and
# numerical correctness. There is likely a way to optimize the performance,
# I just haven't had time to look into it.

denormal_exp_biased = 1 - exp_bias + F32_EXP_BIAS

# 1st iteration: 1
# 2nd iteration: 10, 11
# 3rd iteration: 100, 101, 110, 111
# and so on
for i in range(mbits):
for j in range(1 << i, 1 << (i+1)):
left_shift = mbits - i
mantissa = (j - (1 << i)) << (left_shift + MBITS_F32 - mbits)
exp = (denormal_exp_biased - left_shift) << MBITS_F32
mantissa_lp_int32[mantissa_lp_int32 == j] = exp | mantissa

result = torch.where(denormal_mask, mantissa_lp_int32, result)

# # for denormal numbers, we shift left until mantissa bits overflow
# # when doing so, we also need to subtract exponent by the same amount.
# for i in range(mbits):
# mask = torch.logical_and(denormal_mask, mantissa_lp_int32 >= (1 << i))

# left_shift = mbits - i
# this_mantissa_f32 = (mantissa_f32 << left_shift) & F32_MANTISSA_MASK
# this_exp_biased_f32 = (denormal_exp_biased - left_shift) << MBITS_F32

# result = torch.where(mask, this_exp_biased_f32 | this_mantissa_f32, result)
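# fast path for mbits == 1: the only non-zero denormal mantissa is 1, so every
# denormal input maps to the same FP32 bit pattern and the loop below is unnecessary.
# without this fast path, FP4_E2M1 is about 2x slower.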
if mbits == 1:
result[denormal_mask] = (denormal_exp_biased - mbits) << MBITS_F32

else:
# iterate over all possible non-zero values of the mantissa
# (mantissa_cmp is shown in binary)
# i=0, mantissa_cmp=1
# i=1, mantissa_cmp=10,11
# i=2, mantissa_cmp=100,101,110,111
# and so on
for i in range(mbits):
for mantissa_cmp in range(1 << i, 1 << (i+1)):
# left shift mantissa until it overflows (create an implicit 1)
# subtract exponent by the same amount
left_shift = mbits - i
mantissa_f32 = (mantissa_cmp - (1 << i)) << (left_shift + MBITS_F32 - mbits)
exp_biased_f32 = (denormal_exp_biased - left_shift) << MBITS_F32

# we can update this in-place since the values won't overlap
mantissa_lp_int32[mantissa_lp_int32 == mantissa_cmp] = exp_biased_f32 | mantissa_f32

result = torch.where(denormal_mask, mantissa_lp_int32, result)

# add sign back
sign_f32 = sign_lp.to(torch.int32) << (MBITS_F32 - mbits + EBITS_F32 - ebits)
Expand Down
74 changes: 0 additions & 74 deletions torchao/prototype/mx_formats/custom_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import struct

import numpy as np

import torch
Expand All @@ -25,8 +23,6 @@
E8M0_EXPONENT_NAN_VAL,
F32_EXP_BIAS,
F4_E2M1_EXP_BIAS,
F6_E2M3_EXP_BIAS,
F6_E3M2_EXP_BIAS,
)


Expand All @@ -45,78 +41,8 @@ def get_bits(x: torch.Tensor) -> str:
EBITS_F6_E2M3, MBITS_F6_E2M3 = 2, 3
EBITS_F6_E3M2, MBITS_F6_E3M2 = 3, 2

DENORM_F32TOF4_EXP = (
# exp bias conversion between formats
(F32_EXP_BIAS - F4_E2M1_EXP_BIAS)
# mantissa length difference between formats
+ (MBITS_F32 - MBITS_F4_E2M1)
# add one to encoded exponent for denormalized numbers
+ 1
)
DENORM_F32TOF4_MASK_INT = DENORM_F32TOF4_EXP << MBITS_F32
# reinterpret int32 as float32 in Python
# see https://stackoverflow.com/a/34446112/1058521
DENORM_F32TOF4_MASK_FLOAT = struct.unpack(
"!f", struct.pack("!I", DENORM_F32TOF4_MASK_INT)
)[0]

DENORM_F32TOF6_E2M3_EXP = (
# exp bias conversion between formats
(F32_EXP_BIAS - F6_E2M3_EXP_BIAS)
# mantissa length difference between formats
+ (MBITS_F32 - MBITS_F6_E2M3)
# add one to encoded exponent for denormalized numbers
+ 1
)
DENORM_F32TOF6_E2M3_MASK_INT = DENORM_F32TOF6_E2M3_EXP << MBITS_F32
# reinterpret int32 as float32 in Python
# see https://stackoverflow.com/a/34446112/1058521
DENORM_F32TOF6_E2M3_MASK_FLOAT = struct.unpack(
"!f", struct.pack("!I", DENORM_F32TOF6_E2M3_MASK_INT)
)[0]

DENORM_F32TOF6_E3M2_EXP = (
# exp bias conversion between formats
(F32_EXP_BIAS - F6_E3M2_EXP_BIAS)
# mantissa length difference between formats
+ (MBITS_F32 - MBITS_F6_E3M2)
# add one to encoded exponent for denormalized numbers
+ 1
)
DENORM_F32TOF6_E3M2_MASK_INT = DENORM_F32TOF6_E3M2_EXP << MBITS_F32
# reinterpret int32 as float32 in Python
# see https://stackoverflow.com/a/34446112/1058521
DENORM_F32TOF6_E3M2_MASK_FLOAT = struct.unpack(
"!f", struct.pack("!I", DENORM_F32TOF6_E3M2_MASK_INT)
)[0]

#
# magic value to add during the normal path
# TODO document this better
#

# c++ code e5m2:
# f_bits += ((uint32_t)(15 - 127) << 23) + 0xFFFFF;
# 0xFFFFF is 1111 1111 1111 1111 1111, 20 ones, 20 = 23 - 3 = 23 - 2 - 1

# c++ code e4m3:
# f_bits += ((uint32_t)(7 - 127) << 23) + 0x7FFFF;
# 0x7FFFF is 0111 1111 1111 1111 1111, 19 ones, 19 = 23 - 4 = 23 - 3 - 1

MAGIC_ADDER_F4_E2M1 = 0x1FFFFF # 21 ones
MAGIC_ADDER_F6_E2M3 = 0x7FFFF # 19 ones
MAGIC_ADDER_F6_E3M2 = 0xFFFFF # 20 ones

# c++ code named vars
# f_bits += ((uint32_t)(f8_exp_bias - f32_exp_bias) << f32_mbits) + MAGIC_ADDER; # noqa: E501

SIGN_MASK_F4 = 0x8 # 1000
SIGN_MASK_F6_E2M3 = 0x20 # 100000
SIGN_MASK_F6_E3M2 = 0x20 # 100000

MANTISSA_MASK_F4 = 0x1 # 0001
MANTISSA_MASK_F6_E2M3 = 0x7 # 000111
MANTISSA_MASK_F6_E3M2 = 0x3 # 000011

ZERO_BITS_F32 = 0x0
ZERO_POINT_FIVE_BITS_F32 = 0x3F000000
Expand Down

0 comments on commit dcd5a05

Please sign in to comment.