@@ -63,6 +63,18 @@ def is_fp4(quantization_args: QuantizationArgs):
         and quantization_args.type == QuantizationType.FLOAT
     )
 
+def get_power_of_two(x):
+    powers = torch.tensor([0, 1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8).to(x.device)
+
+    # Expand and compute distances
+    diff = (x.unsqueeze(-1).to(torch.int16) - powers.to(torch.int16)).abs()
+
+    # Find nearest index
+    nearest_idx = diff.argmin(dim=-1)
+
+    return powers[nearest_idx]
+
+
 
 def calculate_qparams(
     min_vals: Tensor,
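For reference, the new helper snaps each uint8 value to the nearest entry in {0, 1, 2, 4, 8, 16, 32, 64, 128}. Below is a minimal standalone sketch of that behaviour; the helper body is copied from the hunk above, and the sample inputs are illustrative only:

```python
import torch

def get_power_of_two(x):
    # Candidate values: zero plus every power of two representable in uint8
    powers = torch.tensor([0, 1, 2, 4, 8, 16, 32, 64, 128], dtype=torch.uint8).to(x.device)

    # Subtract in int16 to avoid uint8 wraparound, then take absolute distances
    diff = (x.unsqueeze(-1).to(torch.int16) - powers.to(torch.int16)).abs()

    # argmin returns the first minimum, so ties snap to the smaller power
    nearest_idx = diff.argmin(dim=-1)

    return powers[nearest_idx]

x = torch.tensor([0, 3, 5, 100, 200], dtype=torch.uint8)
print(get_power_of_two(x))  # -> [0, 2, 4, 128, 128]
```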
@@ -93,33 +105,50 @@ def calculate_qparams(
     bit_range = bit_max - bit_min
 
     if is_fp4(quantization_args=quantization_args):
-        zp_dtype = FP8_E4M3_DATA.dtype
+        if quantization_args.group_size == 16:
+            zp_dtype = FP8_E4M3_DATA.dtype
+        else:
+            # group_size 32
+            zp_dtype = torch.uint8
     else:
         zp_dtype = quantization_args.pytorch_dtype()
 
     if quantization_args.symmetric:
         max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
 
-        if is_fp4(quantization_args=quantization_args) and global_scale is not None:
-            # Conditionally scale the generated local scale by a global_scale
-            scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max)
-            scales = torch.clamp(scales, max=FP8_E4M3_DATA.max, min=FP8_E4M3_DATA.min)
-            scales = scales.to(FP8_E4M3_DATA.dtype)
+        if is_fp4(quantization_args=quantization_args):
+            if global_scale is not None:
+                # Conditionally scale the generated local scale by a global_scale
+                scales = global_scale * (max_val_pos / FP4_E2M1_DATA.max)
+                scales = torch.clamp(
+                    scales, max=FP8_E4M3_DATA.max, min=FP8_E4M3_DATA.min
+                )
+                scales = scales.to(FP8_E4M3_DATA.dtype)
+            else:
+
+                scales = torch.iinfo(torch.uint8).max * (max_val_pos)  # / FP4_E2M1_DATA.max)
+                scales = torch.clamp(
+                    scales,
+                    max=torch.iinfo(torch.uint8).max,
+                    min=torch.iinfo(torch.uint8).min,
+                )
+                scales = scales.to(torch.uint8)
+                scales = get_power_of_two(scales)
 
         else:
             scales = max_val_pos / (float(bit_range) / 2)
 
         # TODO: in the case of MoEs, the global_scale may also be 0/need to be clamped
-        if scales.dtype == FP8_E4M3_DATA.dtype:
-            # torch.clamp not supported for FP8
-            # use the next largest fp8 value from 0
-            scales = torch.where(
-                scales == 0,
-                torch.tensor(0.125, dtype=FP8_E4M3_DATA.dtype, device=device),
-                scales,
-            )
-        else:
-            scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
+        # if scales.dtype == FP8_E4M3_DATA.dtype:
+        #     torch.clamp not supported for FP8
+        #     use the next largest fp8 value from 0
+        #     scales = torch.where(
+        #         scales == 0,
+        #         torch.tensor(0.125, dtype=FP8_E4M3_DATA.dtype, device=device),
+        #         scales,
+        #     )
+        # else:
+        #     scales = torch.clamp(scales, min=torch.finfo(torch.float32).eps)
 
         zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
     else:
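For orientation, here is a rough standalone sketch of the two symmetric FP4 scale paths above, reusing the `get_power_of_two` helper from the first hunk. The numeric limits (E2M1 max of 6.0, E4M3 max/min of ±448.0), the toy inputs, and the variable names are assumptions for illustration, not taken from the diff; `torch.float8_e4m3fn` requires a recent PyTorch (>= 2.1):

```python
import torch

FP4_E2M1_MAX = 6.0    # assumed value of FP4_E2M1_DATA.max
FP8_E4M3_MAX = 448.0  # assumed value of FP8_E4M3_DATA.max

# Toy per-group absolute maxima standing in for max_val_pos
max_val_pos = torch.tensor([0.01, 0.1, 0.3])

# group_size == 16 path: FP8 local scales conditioned on a global scale
global_scale = torch.tensor(100.0)
scales_fp8 = global_scale * (max_val_pos / FP4_E2M1_MAX)
scales_fp8 = torch.clamp(scales_fp8, max=FP8_E4M3_MAX, min=-FP8_E4M3_MAX)
scales_fp8 = scales_fp8.to(torch.float8_e4m3fn)

# group_size == 32 path: uint8 scales snapped to the nearest power of two
scales_u8 = torch.iinfo(torch.uint8).max * max_val_pos
scales_u8 = torch.clamp(
    scales_u8, max=torch.iinfo(torch.uint8).max, min=torch.iinfo(torch.uint8).min
)
scales_u8 = get_power_of_two(scales_u8.to(torch.uint8))
print(scales_u8)  # -> [2, 32, 64]
```

Note that with the clamping block commented out, a zero `max_val_pos` now yields a zero scale on both paths; the TODO above suggests this is still an open question for MoE inputs.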