2323from fastdeploy .distributed .communication import tensor_model_parallel_all_reduce
2424from fastdeploy .model_executor .layers .moe .fused_moe_backend_base import MoEMethodBase
2525from fastdeploy .model_executor .layers .moe .moe import get_moe_scores
26+ from fastdeploy .model_executor .layers .quantization .weight_only import WeightOnlyConfig
2627from fastdeploy .model_executor .layers .utils import get_tensor
2728from fastdeploy .model_executor .ops .gpu import (
2829 fused_expert_moe ,
2930 moe_expert_dispatch ,
3031 moe_expert_ffn ,
3132 moe_expert_reduce ,
3233)
33- from fastdeploy .model_executor .utils import TensorTracker , free_tensor , set_weight_attrs
34+ from fastdeploy .model_executor .utils import (
35+ TensorTracker ,
36+ free_tensor ,
37+ process_weight_transpose ,
38+ set_weight_attrs ,
39+ weight_fully_copied ,
40+ )
3441
3542
3643class MetaxCutlassMoEMethod (MoEMethodBase ):
@@ -142,18 +149,11 @@ def apply_tp(
142149 1.0 ,
143150 )
144151 else :
145- added_weight_attrs0 = getattr (layer , self .added_weight_attrs [0 ])
146- added_weight_attrs1 = getattr (layer , self .added_weight_attrs [1 ])
147-
148- if self .quant_config .is_checkpoint_bf16 and layer .fd_config .load_config .load_choices == "default_v1" :
149- added_weight_attrs0 = paddle .transpose (added_weight_attrs0 , perm = [0 , 2 , 1 ])
150- added_weight_attrs1 = paddle .transpose (added_weight_attrs1 , perm = [0 , 2 , 1 ])
151-
152152 fused_moe_out = fused_expert_moe (
153153 x ,
154154 gate .weight ,
155- added_weight_attrs0 ,
156- added_weight_attrs1 ,
155+ getattr ( layer , self . added_weight_attrs [ 0 ]) ,
156+ getattr ( layer , self . added_weight_attrs [ 1 ]) ,
157157 None ,
158158 (layer .up_gate_proj_weight_scale if hasattr (layer , "up_gate_proj_weight_scale" ) else None ),
159159 None ,
@@ -177,7 +177,10 @@ class MetaxCutlassWeightOnlyMoEMethod(MetaxCutlassMoEMethod):
177177
178178 def __init__ (self , quant_config ):
179179 super ().__init__ (quant_config )
180- self .quant_config = quant_config
180+ if quant_config is None :
181+ self .quant_config = WeightOnlyConfig (algo = "weight_only_int8" , is_checkpoint_bf16 = True )
182+ else :
183+ self .quant_config = quant_config
181184 self .moe_quant_type = self .quant_config .algo
182185 self .pack_num = 1
183186 self .weight_only_linear_arch = os .getenv ("FLAGS_weight_only_linear_arch" )
@@ -252,33 +255,61 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
252255 ]
253256 self .up_gate_proj_scale_shape = [layer .num_local_experts , layer .moe_intermediate_size * 2 ]
254257 self .down_proj_scale_shape = [layer .num_local_experts , layer .hidden_size ]
258+ self .model_format = extra_weight_attrs .get ("model_format" )
255259 # TODO(bukejiyu): remove v1 loader check when v0 loader is removed
256260 if self .quant_config .is_checkpoint_bf16 and layer .fd_config .load_config .load_choices == "default_v1" :
261+ if self .model_format != "torch" :
262+ up_gate_proj_weight_shape = [
263+ layer .num_local_experts ,
264+ layer .hidden_size ,
265+ layer .moe_intermediate_size * 2 ,
266+ ]
267+ down_proj_weight_shape = [layer .num_local_experts , layer .moe_intermediate_size , layer .hidden_size ]
268+ up_gate_proj_attrs = {
269+ ** extra_weight_attrs ,
270+ "tensor_track" : TensorTracker (shape = up_gate_proj_weight_shape , output_dim = True ),
271+ }
272+ down_proj_attrs = {
273+ ** extra_weight_attrs ,
274+ "tensor_track" : TensorTracker (shape = down_proj_weight_shape , output_dim = False ),
275+ }
276+ else :
277+ up_gate_proj_weight_shape = [
278+ layer .num_local_experts ,
279+ layer .moe_intermediate_size * 2 ,
280+ layer .hidden_size ,
281+ ]
282+ down_proj_weight_shape = [layer .num_local_experts , layer .hidden_size , layer .moe_intermediate_size ]
283+ up_gate_proj_attrs = {
284+ ** extra_weight_attrs ,
285+ "tensor_track" : TensorTracker (shape = up_gate_proj_weight_shape , output_dim = False ),
286+ "SHARD_ID_TO_SHARDED_DIM" : {"gate" : 0 , "down" : 1 , "up" : 0 },
287+ }
288+ down_proj_attrs = {
289+ ** extra_weight_attrs ,
290+ "tensor_track" : TensorTracker (shape = down_proj_weight_shape , output_dim = True ),
291+ "SHARD_ID_TO_SHARDED_DIM" : {"gate" : 0 , "down" : 1 , "up" : 0 },
292+ }
293+
257294 layer .up_gate_proj_weight = layer .create_parameter (
258- shape = [ layer . num_local_experts , layer . hidden_size , layer . moe_intermediate_size * 2 ] ,
295+ shape = up_gate_proj_weight_shape ,
259296 dtype = layer .weight_dtype ,
260297 default_initializer = paddle .nn .initializer .Constant (0 ),
261298 )
262299
263300 layer .down_proj_weight = layer .create_parameter (
264- shape = [ layer . num_local_experts , layer . moe_intermediate_size , layer . hidden_size ] ,
301+ shape = down_proj_weight_shape ,
265302 dtype = layer .weight_dtype ,
266303 default_initializer = paddle .nn .initializer .Constant (0 ),
267304 )
268- extra_weight_attrs ["weight_need_transpose" ] = extra_weight_attrs .get ("model_format" ) == "torch"
305+ # extra_weight_attrs["weight_need_transpose"] = extra_weight_attrs.get("model_format") == "torch"
269306 set_weight_attrs (
270307 layer .up_gate_proj_weight ,
271- {
272- ** extra_weight_attrs ,
273- "tensor_track" : TensorTracker (shape = layer .up_gate_proj_weight .shape , output_dim = True ),
274- },
308+ up_gate_proj_attrs ,
275309 )
276310 set_weight_attrs (
277311 layer .down_proj_weight ,
278- {
279- ** extra_weight_attrs ,
280- "tensor_track" : TensorTracker (shape = layer .down_proj_weight .shape , output_dim = False ),
281- },
312+ down_proj_attrs ,
282313 )
283314 else :
284315 self .weight_dtype = "int8"
@@ -325,7 +356,7 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
325356 default_initializer = paddle .nn .initializer .Constant (0 ),
326357 ),
327358 )
328- extra_weight_attrs ["weight_need_transpose" ] = not extra_weight_attrs .get ("model_format" ) == "torch"
359+ # extra_weight_attrs["weight_need_transpose"] = not extra_weight_attrs.get("model_format") == "torch"
329360 moe_extra_weight_attrs = {** extra_weight_attrs , "SHARD_ID_TO_SHARDED_DIM" : {"gate" : 0 , "down" : 1 , "up" : 0 }}
330361 set_weight_attrs (layer .up_gate_proj_weight , moe_extra_weight_attrs )
331362 set_weight_attrs (layer .down_proj_weight , moe_extra_weight_attrs )
@@ -337,69 +368,71 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
337368 set_weight_attrs (layer .down_proj_weight_scale , scale_extra_weight_attrs )
338369
def process_weights_after_loading(self, layer):
    """Quantize one loaded expert weight of ``layer`` and swap in the result.

    Does nothing unless ``self.quant_config.is_checkpoint_bf16`` is set.
    Otherwise one weight is handled per call: the gate/up projection when
    ``weight_fully_copied(layer.up_gate_proj_weight)`` is true, the down
    projection in every other case. When ``self.model_format`` is
    ``"torch"`` the unquantized weight is first passed through
    ``process_weight_transpose``. The weight is then quantized expert by
    expert with ``weight_quantize``, the unquantized tensor is released
    with ``free_tensor``, and fresh weight/scale parameters are created on
    ``layer`` and filled with the quantized data.

    Args:
        layer: MoE layer owning the expert weights; must expose
            ``num_local_experts`` and ``create_parameter``.

    Returns:
        None.
    """
    # Guard clause: only checkpoints flagged as bf16 are processed here.
    if not self.quant_config.is_checkpoint_bf16:
        return

    def _quantize_and_replace(slot, kind):
        """Quantize the weight at index ``slot`` and install it on ``layer``."""
        # 1. Resolve attribute names, shapes and dtypes.
        quant_name = self.added_weight_attrs[slot]
        source_name = quant_name.replace("quant_weight", "weight")
        scale_name = self.added_scale_attrs[slot]
        if kind == "gate_up":
            param_shape = self.up_gate_proj_weight_shape
            scale_shape = self.up_gate_proj_scale_shape
        else:
            param_shape = self.down_proj_weight_shape
            scale_shape = self.down_proj_scale_shape
        # The staging buffer uses the last two axes swapped relative to the
        # final parameter; it is transposed back when copied in below.
        staging_shape = [param_shape[0], param_shape[2], param_shape[1]]
        quant_dtype = "int8"
        scale_dtype = self.default_dtype

        # 2. Create temporary tensors that receive the quantized data.
        staged_weight = paddle.empty(staging_shape, dtype=quant_dtype)
        staged_scale = paddle.empty(scale_shape, dtype=scale_dtype)

        # 3. Quantize each local expert separately.
        for expert_id in range(layer.num_local_experts):
            quantized, expert_scale = weight_quantize(
                getattr(layer, source_name)[expert_id],
                algo=self.moe_quant_type,
                arch=self.weight_only_linear_arch,
            )
            staged_weight[expert_id] = quantized
            staged_scale[expert_id] = expert_scale

        # The unquantized weight is no longer needed once every expert is done.
        free_tensor(getattr(layer, source_name))

        # 4. Create the final parameters, then copy the staged data into them.
        quant_param = layer.create_parameter(
            shape=param_shape,
            dtype=quant_dtype,
            default_initializer=paddle.nn.initializer.Constant(0),
        )
        setattr(layer, quant_name, quant_param)
        scale_param = layer.create_parameter(
            shape=scale_shape,
            dtype=scale_dtype,
            default_initializer=paddle.nn.initializer.Constant(0),
        )
        setattr(layer, scale_name, scale_param)
        getattr(layer, quant_name).copy_(staged_weight.transpose([0, 2, 1]), False)
        getattr(layer, scale_name).copy_(staged_scale, False)

    # Pick which of the two weights this call handles.
    kind = "gate_up" if weight_fully_copied(layer.up_gate_proj_weight) else "down"
    slot = {"gate_up": 0, "down": 1}[kind]

    if self.model_format == "torch":
        source_name = self.added_weight_attrs[slot].replace("quant_weight", "weight")
        process_weight_transpose(layer, source_name)
    _quantize_and_replace(slot, kind)
403436
404437 def process_loaded_weights (self , layer : nn .Layer , state_dict ):
405438 """
0 commit comments