Refactor custom FPx cast #363
Merged
Changes from all commits (11 commits)
47f7bc1 refactor custom fp cast (gau-nernst)
da17611 add dequant (gau-nernst)
3345740 small formating (gau-nernst)
2690b92 compile with fullgraph=True (gau-nernst)
8aa0146 add fullgraph=true (gau-nernst)
be77632 undo (gau-nernst)
95f4582 add another version (gau-nernst)
dcd5a05 fast path for mbits=1 (gau-nernst)
f61ff05 Merge branch 'pytorch:main' into custom_fpx (gau-nernst)
4ad065f Merge branch 'pytorch:main' into custom_fpx (gau-nernst)
bd64efc add back docstring (gau-nernst)
@@ -0,0 +1,227 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# This script was initially developed for sub-byte MX dtypes (FP4 E2M1, FP6 E3M2, and FP6 E2M3).
# It has been refactored to support any sub-byte FP dtype. However, some behaviors of MX dtypes remain:
# 1. No encodings are reserved for special values (+/-inf, NaN).
# 2. When downcasting from FP32 to FPx,
#    - The rounding mode is round to nearest, ties to even.
#    - Values outside the representable range of FPx after rounding are clamped to the maximum FPx
#      magnitude (sign is preserved).
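# For example (illustrative, following the max_normal formula used in the cast below):
# FP6 E3M2 has a largest representable magnitude of 28.0, so an FP32 input of 30.0 is
# clamped to 28.0 (and -30.0 to -28.0) rather than mapping to an infinity encoding.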

import torch
from torch import Tensor


def _n_ones(n: int) -> int:
    return (1 << n) - 1


EBITS_F32, MBITS_F32 = 8, 23
F32_EXP_BIAS = _n_ones(EBITS_F32 - 1)


def _f32_to_fpx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
    """Convert FP32 numbers to sub-byte floating point numbers with the given
    number of exponent and mantissa bits.

    Input: torch.Tensor of dtype torch.float
    Output: torch.Tensor of dtype torch.uint8, where the bit encoding is stored
    in the least significant bits, e.g.
      fp4: the upper 4 bits are empty and the lower 4 bits hold the fp4_e2m1 encoding
      fp6: the upper 2 bits are empty and the lower 6 bits hold the fp6_e2m3 or fp6_e3m2 encoding

    Note: there is no support for special values (NaN, inf) in this code. Values
    outside the representable range of FPx after rounding are clamped to the
    maximum FPx magnitude (sign is preserved).

    Code below is an adaptation of https://fburl.com/code/ciwofcg4

    Background 1: last answer in https://stackoverflow.com/questions/8981913/how-to-perform-round-to-even-with-floating-point-numbers  # noqa: E501
    Background 2: Computer Organization and Design, RISC-V edition, Chapter 3.5
    """
    assert x.dtype == torch.float
    assert 1 + ebits + mbits <= 8

    # calculate constants
    exp_bias = _n_ones(ebits - 1)
    max_int = _n_ones(ebits + mbits)
    sign_mask = 1 << (ebits + mbits)

    # TODO document this better
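    # Descriptive note: magic_adder is the round-to-nearest-even bias for dropping the
    # low (MBITS_F32 - mbits) mantissa bits further below. Adding 2 ** (MBITS_F32 - mbits - 1) - 1
    # here, plus the kept LSB ("mant_odd") later, rounds to nearest with ties broken to even.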
    magic_adder = _n_ones(MBITS_F32 - mbits - 1)

    # all E bits and M bits are 1s
    max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits))

    # E bits = 1, M bits = 0
    min_normal = 2 ** (1 - exp_bias)

    denorm_exp = (
        # exp bias conversion between formats
        (F32_EXP_BIAS - exp_bias)
        # mantissa length difference between formats
        + (MBITS_F32 - mbits)
        # add one to encoded exponent for denormalized numbers
        + 1
    )
    denorm_mask_int = denorm_exp << MBITS_F32

    # reinterpret int32 as float32
    denorm_mask_float = torch.tensor(denorm_mask_int, dtype=torch.int32).view(torch.float32)

    # save the sign
    # Note that we have torch.uint32, but some ops like cpu bit shifts
    # do not work on it. So, we stay in int32.
    x = x.view(torch.int32)
    sign = x & 0x80000000

    # set everything to positive, will add sign back at the end
    x = x ^ sign

    # TODO: can the branch floating point comparisons below be done without
    # converting to float? probably but need to verify
    x = x.view(torch.float)

    # rewrite saturate/denorm/norm branches without explicit data dependent
    # control flow, to be more compiler friendly
    saturate_mask = x >= max_normal
    denormal_mask = torch.logical_and(torch.logical_not(saturate_mask), x < min_normal)
    normal_mask = torch.logical_not(torch.logical_or(saturate_mask, denormal_mask))

    #
    # branch 1: saturate to max val - handled later in the code which combines
    # the branches
    #

    #
    # branch 2: conversion to denormal as well as rounding up to normal
    #
    denormal_x = x + denorm_mask_float
    denormal_x = denormal_x.view(torch.int32)
    denormal_x -= denorm_mask_int
    denormal_x = denormal_x.to(torch.uint8)

    #
    # branch 3: stay in normal range, adjust the exponent and round
    #
    normal_x = x.view(torch.int32)
    # resulting mantissa is odd
    mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1
    # update exponent, rounding bias part 1
    val_to_add = ((exp_bias - F32_EXP_BIAS) << MBITS_F32) + magic_adder
    normal_x += val_to_add
    # rounding bias part 2
    normal_x += mant_odd
    # take the bits!
    normal_x = normal_x >> (MBITS_F32 - mbits)
    normal_x = normal_x.to(torch.uint8)

    #
    # combine the branches
    #
    x = torch.full_like(x, max_int, dtype=torch.uint8)
    x = torch.where(denormal_mask, denormal_x, x)
    x = torch.where(normal_mask, normal_x, x)

    # add sign back
    sign_lp = sign >> (MBITS_F32 + EBITS_F32 - mbits - ebits)
    sign_lp = sign_lp.to(torch.uint8)
    # Right shift of a negative signed integer can fill the newly vacated bits
    # with either 1s or 0s, depending on the implementation. Since PyTorch
    # doesn't have a uint32 dtype, we mask out these bits to get just the
    # FPx sign bit.
    sign_lp = sign_lp & sign_mask
    x = x | sign_lp

    return x.to(torch.uint8)
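

# For reference (illustrative values derived from the E2M1 definition, not taken from this PR):
# with ebits=2, mbits=1 the cast above should place, in the low 4 bits,
#   1.0 -> 0b0010, 6.0 (max normal) -> 0b0111, 0.5 (smallest denormal) -> 0b0001, -1.0 -> 0b1010.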


# TODO(future): check if LUT for everything is faster than bit shifting,
# especially for fp4 (only 2^4=16 unique values).
def _fpx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor:
    """Convert sub-byte floating point numbers with the given number of exponent
    and mantissa bits to FP32.

    Input: torch.Tensor of dtype uint8, where the bit encoding is stored
    in the least significant bits, e.g.
      fp4: the upper 4 bits are empty and the lower 4 bits hold the fp4_e2m1 encoding
      fp6: the upper 2 bits are empty and the lower 6 bits hold the fp6_e2m3 or fp6_e3m2 encoding
    Output: torch.Tensor of dtype fp32 with the dequantized value
    """
    assert x.dtype == torch.uint8
    assert 1 + ebits + mbits <= 8

    sign_mask = 1 << (ebits + mbits)
    exp_bias = _n_ones(ebits - 1)
    mantissa_mask = _n_ones(mbits)

    # save the sign
    sign_lp = x & sign_mask

    # set everything to positive, will add sign back at the end
    x_pos = x ^ sign_lp

    #
    # 1. Calculate zero mask
    #
    zero_mask = x_pos == 0

    #
    # 2. Calculate the denormal path mask
    #
    denormal_mask = torch.logical_and((x_pos > 0), ((x_pos >> mbits) == 0))

    #
    # 3. Calculate the normal path
    #

    # calculate the new exponent and shift it to bits 2:9 of the result
    exp_biased_lp = x_pos >> mbits
    exp_biased_f32 = exp_biased_lp - exp_bias + F32_EXP_BIAS
    exp_biased_f32 = exp_biased_f32.to(torch.int32) << MBITS_F32

    # shift the mantissa to bits 10:32 of the result
    mantissa_lp_int32 = (x_pos & mantissa_mask).to(torch.int32)
    mantissa_f32 = mantissa_lp_int32 << (MBITS_F32 - mbits)
    result = exp_biased_f32 | mantissa_f32

    #
    # 4. Add the zero and denormal casts to the already casted normal path
    #
    result[zero_mask] = 0

    denormal_exp_biased = 1 - exp_bias + F32_EXP_BIAS

    # fast path.
    # without this, performance for FP4_E2M1 is slower by 2x
    if mbits == 1:
        result[denormal_mask] = (denormal_exp_biased - mbits) << MBITS_F32

    else:
        # iterate over all possible values of the mantissa
        # (mantissa_cmp values shown in binary)
        # i=0, mantissa_cmp=1
        # i=1, mantissa_cmp=10,11
        # i=2, mantissa_cmp=100,101,110,111
        # and so on
        for i in range(mbits):
            for mantissa_cmp in range(1 << i, 1 << (i + 1)):
                # left shift mantissa until it overflows (create an implicit 1)
                # subtract exponent by the same amount
                left_shift = mbits - i
                mantissa_f32 = (mantissa_cmp - (1 << i)) << (left_shift + MBITS_F32 - mbits)
                exp_biased_f32 = (denormal_exp_biased - left_shift) << MBITS_F32

                # we can update this in-place since the values won't overlap
                mantissa_lp_int32[mantissa_lp_int32 == mantissa_cmp] = exp_biased_f32 | mantissa_f32

        result = torch.where(denormal_mask, mantissa_lp_int32, result)

    # add sign back
    sign_f32 = sign_lp.to(torch.int32) << (MBITS_F32 - mbits + EBITS_F32 - ebits)
    result = result | sign_f32

    return result.view(torch.float)
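
A minimal usage sketch of the two helpers in this diff (illustrative, not part of the PR): the tensor values and the FP6 E3M2 choice are only examples, and the helpers are assumed to be imported from the module shown above.

import torch
# import _f32_to_fpx_unpacked and _fpx_unpacked_to_f32 from the module in this diff

x = torch.tensor([0.0, 0.1875, 1.0, 30.0, -30.0])

# quantize to unpacked FP6 E3M2 (3 exponent bits, 2 mantissa bits);
# each value is stored in the low 6 bits of a uint8
fp6 = _f32_to_fpx_unpacked(x, ebits=3, mbits=2)

# dequantize back to float32
y = _fpx_unpacked_to_f32(fp6, ebits=3, mbits=2)

# expected behavior per the docstrings: representable values (0.0, 0.1875, 1.0)
# roundtrip exactly, while +/-30.0 saturate to the max FP6 E3M2 magnitude (+/-28.0)
print(y)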
should we have a docblock?