
Commit 24b61bd

add mxfp4 calibration support
1 parent 7e4c47d commit 24b61bd

3 files changed: +53 −18 lines

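For context, "mxfp4" most likely refers to the microscaling FP4 format, in which each group of 32 FP4 (E2M1) values shares a single power-of-two scale. The sketch below is a generic illustration of that block-scaling idea, not code from this repository; the grid constant and function name are made up for the example.

import torch

# FP4 (E2M1) representable magnitudes.
FP4_E2M1_GRID = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def mx_style_quantize_group(x: torch.Tensor):
    # One simple way to pick a power-of-two shared scale for the group.
    amax = x.abs().max().clamp(min=torch.finfo(torch.float32).eps)
    exponent = torch.floor(torch.log2(amax / FP4_E2M1_GRID.max()))
    scale = torch.exp2(exponent)
    # Snap each scaled value to the nearest FP4 magnitude, keeping its sign.
    scaled = x / scale
    idx = (scaled.abs().unsqueeze(-1) - FP4_E2M1_GRID).abs().argmin(dim=-1)
    return FP4_E2M1_GRID[idx] * scaled.sign(), scale

x = torch.randn(32)                       # one MX-style group
q, scale = mx_style_quantize_group(x)
x_hat = q * scale                         # dequantized approximation of x

The commit itself stores the per-group scale as a uint8 snapped to a power of two and divides it by 255 when applied; the diffs below show where that happens.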

src/compressed_tensors/quantization/lifecycle/forward.py

Lines changed: 3 additions & 1 deletion
@@ -468,6 +468,7 @@ def _quantize(
     if global_scale is not None:
         scale = scale.to(global_scale.dtype) / global_scale
 
+    scale = scale.to(x.dtype) / torch.iinfo(torch.uint8).max
     scaled = x / scale
 
     if zero_point is not None:
@@ -501,6 +502,8 @@ def _dequantize(
     if global_scale is not None:
         scale = scale.to(global_scale.dtype) / global_scale
 
+    scale = scale.to(torch.float16) / torch.iinfo(torch.uint8).max
+
     dequant_value = x_q.to(scale.dtype)
 
     if zero_point is not None:
@@ -510,5 +513,4 @@ def _dequantize(
 
     if dtype is not None:
         dequant_value = dequant_value.to(dtype)
-
     return dequant_value
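Both added lines divide the stored scale by torch.iinfo(torch.uint8).max (255) before it is used, which matches the calibration change further down where the scale is written out as a uint8 that was multiplied by 255. A minimal standalone sketch of that decoding step; decode_uint8_scale is a hypothetical name, not part of this commit.

import torch

def decode_uint8_scale(stored_scale: torch.Tensor, compute_dtype: torch.dtype) -> torch.Tensor:
    # Mirrors the added lines in _quantize / _dequantize: map the uint8-stored
    # scale back to a small floating-point multiplier.
    return stored_scale.to(compute_dtype) / torch.iinfo(torch.uint8).max

stored = torch.tensor([32], dtype=torch.uint8)       # e.g. a power-of-two value from calibration
scale = decode_uint8_scale(stored, torch.float16)    # ~0.1255
x = torch.randn(4, dtype=torch.float16)
scaled = x / scale                                   # as in _quantize: scaled = x / scale
x_hat = scaled * scale                               # rough inverse, ignoring the rounding/clamping _quantize does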

src/compressed_tensors/quantization/lifecycle/initialize.py

Lines changed: 5 additions & 1 deletion
@@ -216,7 +216,11 @@ def _initialize_scale_zero_point(
     scale_dtype = scale_dtype if scale_dtype is not None else module.weight.dtype
 
     if is_fp4(quantization_args=quantization_args):
-        scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
+        if quantization_args.group_size == 16:
+            scale_dtype = zp_dtype = FP8_E4M3_DATA.dtype
+        else:
+            # group_size 32
+            scale_dtype = zp_dtype = torch.uint8
     else:
         # TODO: consider erroring out in the future as if the dtype if not one of these,
         # there is likely bug
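The new branch keys the scale and zero-point dtype off group_size: a group size of 16 keeps the FP8 (E4M3) scale dtype used before this commit, while the group_size-32 path added for mxfp4 calibration stores them as torch.uint8. A standalone sketch of that selection; pick_fp4_scale_dtype is a made-up name, and FP8_E4M3_DATA.dtype is assumed to resolve to torch.float8_e4m3fn.

import torch

def pick_fp4_scale_dtype(group_size: int) -> torch.dtype:
    # group_size 16: keep the FP8 E4M3 scale dtype.
    # any other group size (32 in this commit): use uint8 scales/zero-points.
    if group_size == 16:
        return torch.float8_e4m3fn
    return torch.uint8

assert pick_fp4_scale_dtype(16) == torch.float8_e4m3fn
assert pick_fp4_scale_dtype(32) == torch.uint8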

src/compressed_tensors/quantization/utils/helpers.py

Lines changed: 45 additions & 16 deletions
@@ -63,6 +63,18 @@ def is_fp4(quantization_args: QuantizationArgs):
         and quantization_args.type == QuantizationType.FLOAT
     )
 
+def get_power_of_two(x):
+    powers = torch.tensor([0, 1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8).to(x.device)
+
+    # Expand and compute distances
+    diff = (x.unsqueeze(-1).to(torch.int16) - powers.to(torch.int16)).abs()
+
+    # Find nearest index
+    nearest_idx = diff.argmin(dim=-1)
+
+    return powers[nearest_idx]
+
+
 
 def calculate_qparams(
     min_vals: Tensor,
@@ -93,33 +105,50 @@ def calculate_qparams(
     bit_range = bit_max - bit_min
 
     if is_fp4(quantization_args=quantization_args):
-        zp_dtype = FP8_E4M3_DATA.dtype
+        if quantization_args.group_size == 16:
+            zp_dtype = FP8_E4M3_DATA.dtype
+        else:
+            # group_size 32
+            zp_dtype = torch.uint8
     else:
         zp_dtype = quantization_args.pytorch_dtype()
 
     if quantization_args.symmetric:
         max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
 
-        if is_fp4(quantization_args=quantization_args) and global_scale is not None:
-            # Conditionally scale the generated local scale by a global_scale
-            scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max)
-            scales = torch.clamp(scales, max=FP8_E4M3_DATA.max, min=FP8_E4M3_DATA.min)
-            scales = scales.to(FP8_E4M3_DATA.dtype)
+        if is_fp4(quantization_args=quantization_args):
+            if global_scale is not None:
+                # Conditionally scale the generated local scale by a global_scale
+                scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max)
+                scales = torch.clamp(
+                    scales, max=FP8_E4M3_DATA.max, min=FP8_E4M3_DATA.min
+                )
+                scales = scales.to(FP8_E4M3_DATA.dtype)
+            else:
+
+                scales = torch.iinfo(torch.uint8).max * (max_val_pos)  # / FP4_E2M1_DATA.max)
+                scales = torch.clamp(
+                    scales,
+                    max=torch.iinfo(torch.uint8).max,
+                    min=torch.iinfo(torch.uint8).min,
+                )
+                scales = scales.to(torch.uint8)
+                scales = get_power_of_two(scales)
 
         else:
             scales = max_val_pos / (float(bit_range) / 2)
 
         # TODO: in the case of MoEs, the global_scale may also be 0/need to be clamped
-        if scales.dtype == FP8_E4M3_DATA.dtype:
-            # torch.clamp not supported for FP8
-            # use the next largest fp8 value from 0
-            scales = torch.where(
-                scales == 0,
-                torch.tensor(0.125, dtype=FP8_E4M3_DATA.dtype, device=device),
-                scales,
-            )
-        else:
-            scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
+        # if scales.dtype == FP8_E4M3_DATA.dtype:
+            # torch.clamp not supported for FP8
+            # use the next largest fp8 value from 0
+            # scales = torch.where(
+            #     scales == 0,
+            #     torch.tensor(0.125, dtype=FP8_E4M3_DATA.dtype, device=device),
+            #     scales,
+            # )
+        # else:
+            # scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
 
         zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
     else:
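Putting the helpers.py pieces together: on the group_size-32 FP4 path with no global_scale, the per-group max is multiplied by 255, clamped into uint8 range, cast to uint8, and then snapped to the nearest power of two by get_power_of_two. A small walk-through using the helper exactly as added above; the max_val_pos values are made up for illustration.

import torch

def get_power_of_two(x):
    # Same as the helper added in this commit: snap each uint8 value to the
    # nearest entry in {0, 1, 2, 4, 8, 16, 32, 64, 128}.
    powers = torch.tensor([0, 1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8).to(x.device)
    diff = (x.unsqueeze(-1).to(torch.int16) - powers.to(torch.int16)).abs()
    nearest_idx = diff.argmin(dim=-1)
    return powers[nearest_idx]

max_val_pos = torch.tensor([0.02, 0.11, 0.40])             # hypothetical per-group abs-max values
scales = torch.iinfo(torch.uint8).max * max_val_pos        # 5.1, 28.05, 102.0
scales = torch.clamp(scales, min=torch.iinfo(torch.uint8).min, max=torch.iinfo(torch.uint8).max)
scales = scales.to(torch.uint8)                            # 5, 28, 102
print(get_power_of_two(scales))                            # tensor([4, 32, 128], dtype=torch.uint8)

Note that the division by FP4_E2M1_DATA.max is commented out in the new branch, so the stored value is effectively 255 * max_val_pos snapped to a power of two, which is what the / 255 added in forward.py undoes at quantize and dequantize time.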
