
Commit d104bf2

[MXFP4] Add scale generation utils (#503)
* add mxfp4 scale generation
* add rounding test
* update
* update
* update
* update
* update
* add additional case
* update
1 parent e88e7d4 commit d104bf2

4 files changed: +186 −3 lines changed

src/compressed_tensors/quantization/quant_args.py

Lines changed: 9 additions & 3 deletions
@@ -25,6 +25,7 @@
 __all__ = [
     "FP8_E4M3_DATA",
     "FP4_E2M1_DATA",
+    "BFLOAT16_DATA",
     "FloatArgs",
     "QuantizationType",
     "QuantizationStrategy",
@@ -38,9 +39,9 @@
 class FloatArgs:
     exponent: int
     mantissa: int
-    bits: int
-    max: float
-    min: float
+    bits: Optional[int] = None
+    max: Optional[float] = None
+    min: Optional[float] = None
     dtype: Optional[torch.dtype] = None


@@ -76,6 +77,11 @@ class FP8_E4M3_DATA(FloatArgs):
     dtype = torch.float8_e4m3fn


+class BFLOAT16_DATA(FloatArgs):
+    exponent = 8
+    mantissa = 7
+
+
 class QuantizationType(str, Enum):
     """
     Enum storing quantization type options
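
The new BFLOAT16_DATA entry deliberately sets only exponent and mantissa, which is why bits, max, and min become Optional on FloatArgs; the mxfp4 utilities below need just those two bit counts to derive their masks. A minimal sketch of that derivation (the standalone constants here are illustrative stand-ins for BFLOAT16_DATA and FP4_E2M1_DATA, not the library classes):

```python
# Illustrative stand-ins for the FloatArgs entries (not the library classes).
BF16_EXPONENT, BF16_MANTISSA = 8, 7  # BFLOAT16_DATA
FP4_MANTISSA = 1                     # FP4_E2M1_DATA (E2M1 -> 1 mantissa bit)

# Sign bit + 8 exponent bits set, 7 mantissa bits cleared.
BFLOAT16_SIGN_EXPONENT_MASK = ((1 << (BF16_EXPONENT + 1)) - 1) << BF16_MANTISSA
assert BFLOAT16_SIGN_EXPONENT_MASK == 0b1_11111111_0000000

# Half a ULP at the fp4 mantissa position, used to nudge values upward
# before masking.
BFLOAT16_VAL_TO_ADD = 1 << (BF16_MANTISSA - FP4_MANTISSA - 1)
assert BFLOAT16_VAL_TO_ADD == 32
```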

src/compressed_tensors/quantization/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -14,3 +14,4 @@

 # flake8: noqa
 from .helpers import *
+from .mxfp4_utils import *
src/compressed_tensors/quantization/utils/mxfp4_utils.py

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from compressed_tensors.quantization.quant_args import BFLOAT16_DATA, FP4_E2M1_DATA
+
+
+__all__ = ["convert_mxfp4_exp_scale", "generate_mxfp4_scales", "round_to_power_2"]
+
+# Reference: https://github.com/vllm-project/vllm/blob/main/tests/quantization/reference_mxfp4.py  # noqa: E501
+
+
+def convert_mxfp4_exp_scale(
+    scale: torch.Tensor, dtype: torch.dtype = torch.bfloat16
+) -> torch.Tensor:
+    """
+    Converts mxfp4 scales. Scales are powers of 2, with the
+    exponents stored in uint8. Converts to a dense dtype so that
+    they can be applied to the weights and activations during QDQ.
+
+    :param scale: uint8 exponent scale
+    :param dtype: dense dtype
+    """
+    assert scale.dtype == torch.uint8
+    scale_exp = scale.to(torch.int32) - 127
+    scale = 2.00 ** (scale_exp.to(torch.float))
+    return scale.to(dtype)
+
+
+def round_to_power_2(x: torch.Tensor) -> torch.Tensor:
+    """
+    Round values to the closest power of 2.
+    This is done by masking the values with BFLOAT16_SIGN_EXPONENT_MASK,
+    which essentially removes the mantissa and keeps the exponent,
+    i.e. the closest power of 2 for the input value.
+
+    E.g:
+    0.0825 = 1.32 (mantissa) x 2**-4 (exponent)
+    0.0825 ==> -4 (exponent) + 127 = 123 = 01111011 (8 bits for bfloat16)
+    0.0825 ==> 0.32 (mantissa) = 0101001 (7 bits for bfloat16)
+    0.0825 == 0b0_01111011_0101001 (bfloat16)
+    0b0_01111011_0101001 & 0b1_11111111_0000000 == 0b0_01111011_0000000
+    Keep the exponent + sign bit to give you the closest power of 2, 0.0625
+
+    :param x: tensor to round to closest power of 2
+    """
+    assert x.dtype == torch.bfloat16
+    x = x.view(torch.uint16).to(torch.int32)
+
+    # Find closest power of 2
+    BFLOAT16_VAL_TO_ADD = 1 << (BFLOAT16_DATA.mantissa - FP4_E2M1_DATA.mantissa - 1)
+    # Add value to push the value to the next exponent
+    BFLOAT16_SIGN_EXPONENT_MASK = (
+        (1 << (BFLOAT16_DATA.exponent + 1)) - 1
+    ) << BFLOAT16_DATA.mantissa
+    # mask to only keep exponent - we conservatively round down
+    # to better represent smaller numbers / prevent overflow
+    block_max_uint = torch.bitwise_and(
+        x + BFLOAT16_VAL_TO_ADD, BFLOAT16_SIGN_EXPONENT_MASK
+    )
+    return block_max_uint.to(torch.uint16).view(torch.bfloat16)
+
+
+def generate_mxfp4_scales(x: torch.Tensor) -> torch.Tensor:
+    """
+    Generate mxfp4 scales. The scales require the following steps:
+    1. Round to the closest power of 2
+    2. Convert to exponent
+    3. Store in uint8
+
+    Called when calculating qparams using observers.
+
+    :param x: tensor to round to closest power of 2
+    :returns: uint8 scales as exponents
+    """
+    # Round to closest power of 2
+    scale_power_2 = round_to_power_2(x)
+    # Convert to exponent
+    scale_exp = 127 + torch.floor(torch.log2(scale_power_2)).to(torch.int32) - 2
+    # Clamp and store in uint8, as expected by mxfp4
+    scale_exp = torch.clamp(
+        scale_exp,
+        max=torch.iinfo(torch.uint8).max,
+        min=torch.iinfo(torch.uint8).min,
+    )
+    return scale_exp.to(torch.uint8)
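
Taken together, a minimal usage sketch of the three helpers (tensor shapes and names are illustrative; the import path follows the tests below):

```python
import torch

from compressed_tensors.quantization.utils import (
    convert_mxfp4_exp_scale,
    generate_mxfp4_scales,
    round_to_power_2,
)

# Four blocks of 32 values, mxfp4's group size.
weight = torch.randn(4, 32, dtype=torch.bfloat16)
block_max = weight.abs().amax(dim=-1)  # per-block absmax, bfloat16

# absmax -> nearest power of 2 -> biased exponent stored in uint8.
scales_u8 = generate_mxfp4_scales(block_max)
assert scales_u8.dtype == torch.uint8

# uint8 exponent -> dense bfloat16 scale, ready for QDQ.
scales_bf16 = convert_mxfp4_exp_scale(scales_u8)

# The `- 2` inside generate_mxfp4_scales divides the rounded block max
# by 4, aligning it with E2M1's largest exponent (fp4 max is 6 = 1.5 * 2**2).
expected = torch.floor(torch.log2(round_to_power_2(block_max))) - 2
assert torch.equal(torch.log2(scales_bf16), expected)
```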
(new mxfp4 utils test file)

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from compressed_tensors.quantization.utils import (
+    convert_mxfp4_exp_scale,
+    generate_mxfp4_scales,
+    round_to_power_2,
+)
+
+
+def test_round_power_2_noise():
+    powers = torch.Tensor(
+        [
+            [2**-10, 2**-9, 2**-8, 2**-7, 2**-6],
+            [2**-5, 2**-4, 2**-3, 2**-2, 2**-1],
+            [2**0, 2**1, 2**-10, 2**-9, 2**-8],
+            [2**-7, 2**-6, 2**-5, 2**-4, 2**-3],
+            [2**-2, 2**-1, 2**0, 2**1, 2**-10],
+        ]
+    ).to(torch.bfloat16)
+
+    noise = torch.rand_like(powers) * 0.2
+    powers_noisy = powers * (1 + noise)
+    rounded = round_to_power_2(powers_noisy)
+    assert torch.equal(rounded, powers)
+
+
+def test_round_power_2():
+    x = torch.Tensor(
+        (
+            [5.687891, -8.291567, -1.540329, -0.315635, 0.965272],
+            [-6.944130, 0.073246, -0.451778, 8.571118, -9.856593],
+            [-0.040571, -0.708509, 2.485657, -4.003352, -0.995600],
+            [0.224199, 5.032586, -1.309816, -0.621958, 7.290238],
+            [-9.848001, -0.290731, 1.501562, 0.379829, -5.312081],
+        )
+    ).to(torch.bfloat16)
+    x_rounded = torch.Tensor(
+        (
+            [4.000000, -8.000000, -1.000000, -0.250000, 1.000000],
+            [-4.000000, 0.062500, -0.500000, 8.000000, -8.000000],
+            [-0.031250, -0.500000, 2.000000, -4.000000, -1.000000],
+            [0.250000, 4.000000, -1.000000, -0.500000, 8.000000],
+            [-8.000000, -0.250000, 1.000000, 0.250000, -4.000000],
+        )
+    ).to(torch.bfloat16)
+    rounded = round_to_power_2(x)
+    assert torch.equal(rounded, x_rounded)
+
+
+def test_mxfp4_scales_e2e():
+    mock_weight = torch.normal(mean=0.0002, std=0.0576, size=(2880, 2880))
+
+    x = mock_weight.reshape(*mock_weight.shape[:-1], -1, 32).to(torch.bfloat16)
+    min_vals = torch.amin(x, dim=-1)
+    max_vals = torch.amax(x, dim=-1)
+
+    min_vals = torch.min(min_vals, torch.zeros_like(min_vals))
+    max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
+    block_max = torch.max(torch.abs(min_vals), torch.abs(max_vals))
+
+    scales_generated = generate_mxfp4_scales(block_max)
+    converted_ct = convert_mxfp4_exp_scale(scales_generated)
+
+    scales_exp = torch.log2(converted_ct)
+    block_max_exp = torch.floor(torch.log2(round_to_power_2(block_max))) - 2
+    assert torch.equal(scales_exp, block_max_exp)
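
As a quick cross-check of the round_to_power_2 docstring example above, the worked 0.0825 case can be verified directly (an illustrative snippet, not part of this commit):

```python
import torch
from compressed_tensors.quantization.utils import round_to_power_2

# 0.0825 = 1.32 * 2**-4; its bf16 mantissa (41/128) stays below the 96/128
# round-up threshold after the +32 nudge, so the exponent is unchanged and
# the value rounds down to 2**-4 = 0.0625.
x = torch.tensor([0.0825], dtype=torch.bfloat16)
assert round_to_power_2(x).item() == 0.0625
```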
