@@ -64,6 +64,18 @@ def is_fp4(quantization_args: QuantizationArgs):
         and quantization_args.type == QuantizationType.FLOAT
     )
 
+def get_power_of_two(x):
+    powers = torch.tensor([0, 1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8).to(x.device)
+
+    # Expand and compute distances
+    diff = (x.unsqueeze(-1).to(torch.int16) - powers.to(torch.int16)).abs()
+
+    # Find nearest index
+    nearest_idx = diff.argmin(dim=-1)
+
+    return powers[nearest_idx]
+
+
 
 def calculate_qparams(
     min_vals: Tensor,
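
As a quick illustration of the new helper (a minimal sketch, not part of the commit; the sample values are hypothetical): each element is snapped to the nearest entry of the lookup table, which holds 0 plus the powers of two representable in uint8.

    import torch

    x = torch.tensor([0, 3, 5, 100, 200], dtype=torch.uint8)
    # get_power_of_two(x) -> tensor([0, 2, 4, 128, 128], dtype=torch.uint8)
    # Ties (3 is equidistant from 2 and 4) resolve to the lower entry,
    # since argmin returns the first index with the minimal distance.
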
@@ -94,33 +106,50 @@ def calculate_qparams(
     bit_range = bit_max - bit_min
 
     if is_fp4(quantization_args=quantization_args):
-        zp_dtype = FP8_E4M3_DATA.dtype
+        if quantization_args.group_size == 16:
+            zp_dtype = FP8_E4M3_DATA.dtype
+        else:
+            # group_size 32
+            zp_dtype = torch.uint8
     else:
         zp_dtype = quantization_args.pytorch_dtype()
 
     if quantization_args.symmetric:
         max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
 
-        if is_fp4(quantization_args=quantization_args) and global_scale is not None:
-            # Conditionally scale the generated local scale by a global_scale
-            scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max)
-            scales = torch.clamp(scales, max=FP8_E4M3_DATA.max, min=FP8_E4M3_DATA.min)
-            scales = scales.to(FP8_E4M3_DATA.dtype)
+        if is_fp4(quantization_args=quantization_args):
+            if global_scale is not None:
+                # Conditionally scale the generated local scale by a global_scale
+                scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max)
+                scales = torch.clamp(
+                    scales, max=FP8_E4M3_DATA.max, min=FP8_E4M3_DATA.min
+                )
+                scales = scales.to(FP8_E4M3_DATA.dtype)
+            else:
+
+                scales = torch.iinfo(torch.uint8).max * (max_val_pos)  # / FP4_E2M1_DATA.max)
+                scales = torch.clamp(
+                    scales,
+                    max=torch.iinfo(torch.uint8).max,
+                    min=torch.iinfo(torch.uint8).min,
+                )
+                scales = scales.to(torch.uint8)
+                scales = get_power_of_two(scales)
 
         else:
             scales = max_val_pos / (float(bit_range) / 2)
 
         # TODO: in the case of MoEs, the global_scale may also be 0/need to be clamped
-        if scales.dtype == FP8_E4M3_DATA.dtype:
-            # torch.clamp not supported for FP8
-            # use the next largest fp8 value from 0
-            scales = torch.where(
-                scales == 0,
-                torch.tensor(0.125, dtype=FP8_E4M3_DATA.dtype, device=device),
-                scales,
-            )
-        else:
-            scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
+        # if scales.dtype == FP8_E4M3_DATA.dtype:
+        # torch.clamp not supported for FP8
+        # use the next largest fp8 value from 0
+        # scales = torch.where(
+        #     scales == 0,
+        #     torch.tensor(0.125, dtype=FP8_E4M3_DATA.dtype, device=device),
+        #     scales,
+        # )
+        # else:
+        # scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
 
         zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
     else:
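
For reference, a standalone sketch of the arithmetic in the new group_size-32 branch (hypothetical per-group maxima; assumes 255 * max_val_pos lands within the uint8 clamp range):

    import torch

    max_val_pos = torch.tensor([0.02, 0.3, 0.9])  # hypothetical per-group maxima

    scales = torch.iinfo(torch.uint8).max * max_val_pos           # 255 * max -> [5.1, 76.5, 229.5]
    scales = torch.clamp(scales, min=0, max=255).to(torch.uint8)  # -> [5, 76, 229]
    # get_power_of_two(scales) -> tensor([4, 64, 128], dtype=torch.uint8)

Note that the division by FP4_E2M1_DATA.max is left commented out in this branch, so the scale is derived from max_val_pos directly rather than from its ratio to the FP4 maximum.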