@@ -88,14 +88,23 @@ def __init__(
         self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
             "input_activations")
 
-        if not (self.weight_quant.strategy == QuantizationStrategy.TENSOR
-                and self.input_quant.strategy == QuantizationStrategy.TENSOR):
+        per_tensor = (self.weight_quant.strategy == QuantizationStrategy.TENSOR
+                      and self.input_quant.strategy
+                      == QuantizationStrategy.TENSOR)
+        per_channel = (
+            self.weight_quant.strategy == QuantizationStrategy.CHANNEL
+            and self.input_quant.strategy == QuantizationStrategy.TOKEN)
+        if not (per_tensor or per_channel):
             raise ValueError(
-                "For FP8 Fused MoE layers, only per-tensor scales "
-                "for weights and activations are supported. Found "
+                "For FP8 Fused MoE layers, we require per tensor "
+                "or channelwise, dynamic per token quantization. Found "
                 f"{self.weight_quant}, {self.input_quant}")
 
         self.static_input_scales = not self.input_quant.dynamic
+        if self.static_input_scales and per_channel:
+            raise ValueError(
+                "For FP8 Fused MoE layer, we require either per tensor or "
+                "channelwise, dynamic per token quantization.")
 
     def create_weights(self, layer: torch.nn.Module, num_experts: int,
                        hidden_size: int, intermediate_size_per_partition: int,
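The hunk above narrows the accepted schemes to two combinations: per-tensor weights with static per-tensor activations, or per-channel weights with dynamic per-token activations. A rough, torch-only sketch of how the two activation modes differ (the helper names and the FP8 bound here are illustrative, not part of this change):

    import torch

    FP8_E4M3_MAX = 448.0  # finfo max of torch.float8_e4m3fn

    def per_tensor_scale(x: torch.Tensor) -> torch.Tensor:
        # Static mode: one scale for the whole activation tensor,
        # calibrated offline and loaded as an input scale.
        return x.abs().max() / FP8_E4M3_MAX

    def per_token_scales(x: torch.Tensor) -> torch.Tensor:
        # Dynamic mode: one scale per token (row), recomputed every
        # forward pass, so no static input scales exist to load.
        return x.abs().amax(dim=-1, keepdim=True) / FP8_E4M3_MAX

    x = torch.randn(4, 16)
    print(per_tensor_scale(x).shape)   # torch.Size([])
    print(per_token_scales(x).shape)   # torch.Size([4, 1])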
@@ -123,24 +132,40 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
         # WEIGHT_SCALES
-        # Allocate 2 scales for w1 and w3 respectively.
-        # They will be combined to a single scale after weight loading.
-        w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                         2,
-                                                         dtype=torch.float32),
-                                              requires_grad=False)
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            # Allocate 2 scales for w1 and w3 respectively.
+            # They are combined to a single scale after weight loading.
+            w13_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, 2, dtype=torch.float32),
+                                                  requires_grad=False)
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            w2_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # Add PER-TENSOR quantization for FusedMoE.weight_loader.
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
-        w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
-                                                        dtype=torch.float32),
-                                             requires_grad=False)
-        layer.register_parameter("w2_weight_scale", w2_weight_scale)
-        # Add the quantization method used (per tensor/grouped/channel)
-        # to ensure the weight scales are loaded in properly
-        extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value})
-        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+        elif self.weight_quant.strategy == QuantizationStrategy.CHANNEL:
+            w13_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                1,
+                dtype=torch.float32),
+                                                  requires_grad=False)
+            layer.register_parameter("w13_weight_scale", w13_weight_scale)
+            w2_weight_scale = torch.nn.Parameter(torch.ones(
+                num_experts, hidden_size, 1, dtype=torch.float32),
+                                                 requires_grad=False)
+            layer.register_parameter("w2_weight_scale", w2_weight_scale)
+            # Add PER-CHANNEL quantization for FusedMoE.weight_loader.
+            extra_weight_attrs.update(
+                {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
+            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
         # INPUT_SCALES
         if self.static_input_scales:
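For the CHANNEL branch above, a loose sketch of how the newly allocated scale shapes line up with the expert weights; the sizes are placeholders for illustration, the point is that one scale per output channel broadcasts over the trailing (input) dimension:

    import torch

    num_experts, hidden_size, intermediate = 8, 64, 128

    # Mirrors the shapes registered above (values are placeholders).
    w13_weight = torch.empty(num_experts, 2 * intermediate, hidden_size)
    w2_weight = torch.empty(num_experts, hidden_size, intermediate)
    w13_weight_scale = torch.ones(num_experts, 2 * intermediate, 1)
    w2_weight_scale = torch.ones(num_experts, hidden_size, 1)

    # One scale per output channel broadcasts across that row's inputs.
    assert (w13_weight * w13_weight_scale).shape == w13_weight.shape
    assert (w2_weight * w2_weight_scale).shape == w2_weight.shape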
@@ -163,6 +188,7 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         # Fp8 moe kernels require a single activation scale.
         # We take the max of all the scales in case they differ.
         if self.static_input_scales:
+            assert self.input_quant.strategy == QuantizationStrategy.TENSOR
             if (layer.w13_input_scale is None or layer.w2_input_scale is None):
                 raise ValueError(
                     "QuantConfig has static quantization, but found "
@@ -204,24 +230,25 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
                 layer.w2_input_scale = torch.nn.Parameter(w2_input_scale,
                                                           requires_grad=False)
 
-        # Fp8 moe kernel needs single weight scale for w13 per expert.
-        # We take the max then dequant and requant each expert.
-        assert layer.w13_weight_scale is not None
-        shard_size = layer.intermediate_size_per_partition
-        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-        for expert_id in range(layer.local_num_experts):
-            start = 0
-            for shard_id in range(2):
-                dq_weight = per_tensor_dequantize(
-                    layer.w13_weight[expert_id][start:start + shard_size, :],
-                    layer.w13_weight_scale[expert_id][shard_id])
-                layer.w13_weight[expert_id][
-                    start:start + shard_size, :], _ = ops.scaled_fp8_quant(
-                        dq_weight, max_w13_scales[expert_id])
-                start += shard_size
-
-        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
-                                                    requires_grad=False)
+        # For Per-TENSOR case, Fp8 moe kernel needs single weight scale
+        # for w13 per expert. Use max then dequant and requant each expert.
+        if self.weight_quant.strategy == QuantizationStrategy.TENSOR:
+            assert layer.w13_weight_scale is not None
+            shard_size = layer.intermediate_size_per_partition
+            max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+            for expert_id in range(layer.local_num_experts):
+                start = 0
+                for shard_id in range(2):
+                    dq_weight = per_tensor_dequantize(
+                        layer.w13_weight[expert_id][start:start +
                                                    shard_size, :],
+                        layer.w13_weight_scale[expert_id][shard_id])
+                    layer.w13_weight[expert_id][
+                        start:start + shard_size, :], _ = ops.scaled_fp8_quant(
+                            dq_weight, max_w13_scales[expert_id])
+                    start += shard_size
+            layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales,
+                                                        requires_grad=False)
 
     def apply(
         self,
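The TENSOR branch above folds the separate w1 and w3 scales into one per expert because the fused kernel accepts only a single w13 scale. A toy, torch-only version of that dequant/requant step (simple stand-ins for per_tensor_dequantize and ops.scaled_fp8_quant, not the vLLM helpers):

    import torch

    shard_size, hidden = 4, 8
    # Pretend-quantized w1/w3 halves that were stored with different scales.
    w13_q = torch.randn(2 * shard_size, hidden).mul(100).round()
    scales = torch.tensor([0.01, 0.03])
    max_scale = scales.max()

    for shard_id in range(2):
        rows = slice(shard_id * shard_size, (shard_id + 1) * shard_size)
        dq = w13_q[rows] * scales[shard_id]      # dequantize with the old scale
        w13_q[rows] = (dq / max_scale).round()   # requantize with the shared max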
@@ -265,6 +292,8 @@ def apply(
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input,
             use_fp8_w8a8=True,
+            per_channel_quant=self.weight_quant.strategy ==
+            QuantizationStrategy.CHANNEL,
             global_num_experts=global_num_experts,
             expert_map=expert_map,
             w1_scale=layer.w13_weight_scale,