@@ -468,45 +468,8 @@ def create_weights(self, module: torch.nn.Module):
 
     def load_weights(self, module: torch.nn.Module, weights: List[Dict],
                      weight_loading_mode: MoEWeightLoadingMode):
-
-        if get_sm_version() == 100:
-            expert_ids = set(module.initial_local_expert_ids)
-            if self.need_load_shared_weights(module):
-                expert_ids.update(
-                    module.layer_load_balancer.get_load_expert_ids())
-            for name in list(weights.keys()):
-                if name.endswith("weight_scale_inv"):
-                    if int(name.split(".")[0]) not in expert_ids:
-                        continue
-                    weight_name = name.replace("weight_scale_inv", "weight")
-                    logger.debug(f"Resmoothing {weight_name}")
-                    weight = weights[weight_name][:]
-                    scale = weights[name][:]
-                    weights[weight_name], weights[name] = resmooth_to_fp8_e8m0(
-                        weight, scale)
         super().load_weights(module, weights, weight_loading_mode)
 
-        if get_sm_version() == 100:
-            transformed_w3_w1_scale = transform_sf_into_required_layout(
-                module.quant_scales[0],
-                mn=module.w3_w1_weight.shape[1],
-                k=module.w3_w1_weight.shape[2],
-                recipe=(1, 128, 128),
-                num_groups=module.w3_w1_weight.shape[0],
-                is_sfa=False)
-            module.w3_w1_weight_scaling_factor = nn.Parameter(
-                transformed_w3_w1_scale, requires_grad=False)
-            transformed_w2_scale = transform_sf_into_required_layout(
-                module.quant_scales[1],
-                mn=module.w2_weight.shape[1],
-                k=module.w2_weight.shape[2],
-                recipe=(1, 128, 128),
-                num_groups=module.w3_w1_weight.shape[0],
-                is_sfa=False)
-            module.w2_weight_scaling_factor = nn.Parameter(transformed_w2_scale,
-                                                           requires_grad=False)
-            self.setup_quant_scales(module)
-
     def setup_quant_scales(self, module: torch.nn.Module):
         module.quant_scales = FusedMoEQuantScalesDeepSeekFP8BlockScales(
             fc_weight_scales=module.w3_w1_weight_scaling_factor,
@@ -603,6 +566,50 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
             })
 
 
+class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm(
+        DeepSeekFP8BlockScalesFusedMoEMethod):
+
+    def load_weights(self, module: torch.nn.Module, weights: List[Dict],
+                     weight_loading_mode: MoEWeightLoadingMode):
+        if get_sm_version() == 100:
+            expert_ids = set(module.initial_local_expert_ids)
+            if self.need_load_shared_weights(module):
+                expert_ids.update(
+                    module.layer_load_balancer.get_load_expert_ids())
+            for name in list(weights.keys()):
+                if name.endswith("weight_scale_inv"):
+                    if int(name.split(".")[0]) not in expert_ids:
+                        continue
+                    weight_name = name.replace("weight_scale_inv", "weight")
+                    logger.debug(f"Resmoothing {weight_name}")
+                    weight = weights[weight_name][:]
+                    scale = weights[name][:]
+                    weights[weight_name], weights[name] = resmooth_to_fp8_e8m0(
+                        weight, scale)
+        super().load_weights(module, weights, weight_loading_mode)
+
+        if get_sm_version() == 100:
+            transformed_w3_w1_scale = transform_sf_into_required_layout(
+                module.quant_scales[0],
+                mn=module.w3_w1_weight.shape[1],
+                k=module.w3_w1_weight.shape[2],
+                recipe=(1, 128, 128),
+                num_groups=module.w3_w1_weight.shape[0],
+                is_sfa=False)
+            module.w3_w1_weight_scaling_factor = nn.Parameter(
+                transformed_w3_w1_scale, requires_grad=False)
+            transformed_w2_scale = transform_sf_into_required_layout(
+                module.quant_scales[1],
+                mn=module.w2_weight.shape[1],
+                k=module.w2_weight.shape[2],
+                recipe=(1, 128, 128),
+                num_groups=module.w3_w1_weight.shape[0],
+                is_sfa=False)
+            module.w2_weight_scaling_factor = nn.Parameter(transformed_w2_scale,
+                                                           requires_grad=False)
+            self.setup_quant_scales(module)
+
+
 class WInt4AFP8FusedMoEMethod(FusedMoEMethodBase):
 
     def create_weights(self, module: torch.nn.Module):
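
The diff only shows that on SM 100 each locally owned expert's FP8 weight and its weight_scale_inv tensor are rewritten by resmooth_to_fp8_e8m0 before the normal load path runs. As a rough mental model only (an assumption, not the TensorRT-LLM implementation): e8m0 scales carry a pure power-of-two exponent, so each 128x128 block scale is rounded up to the nearest power of two and the leftover ratio is folded back into the weight block, keeping the dequantized product unchanged. A minimal sketch, assuming fp32 tensors and one scale per 128x128 tile; resmooth_sketch and its block layout are hypothetical:

import torch

def resmooth_sketch(weight: torch.Tensor, scale_inv: torch.Tensor,
                    block: int = 128):
    # Hypothetical illustration, not resmooth_to_fp8_e8m0 itself.
    # e8m0 scales store only an exponent, so round each per-block
    # dequant scale up to the nearest power of two...
    new_scale = torch.exp2(torch.ceil(torch.log2(scale_inv)))
    # ...and fold the leftover ratio (in (0.5, 1]) into the matching
    # weight tile so that weight * scale stays numerically unchanged.
    ratio = scale_inv / new_scale
    weight = weight.clone()
    for i in range(scale_inv.shape[0]):
        for j in range(scale_inv.shape[1]):
            weight[i * block:(i + 1) * block,
                   j * block:(j + 1) * block] *= ratio[i, j]
    return weight, new_scale

Rounding up rather than to nearest keeps the fold-back ratio at most 1, so the rescaled weights never grow past the representable FP8 range.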
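
How callers choose the new DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm subclass over its parent is outside this hunk. Since every DeepGemm-specific branch above is guarded by get_sm_version() == 100, one plausible wiring is an SM-version dispatch; the helper below is hypothetical, not part of the diff:

def _select_fp8_block_scales_method():
    # Hypothetical dispatcher: the DeepGemm subclass only changes
    # behavior on SM 100 (Blackwell), where DeepGEMM expects
    # e8m0-resmoothed scales in the (1, 128, 128) layout produced by
    # transform_sf_into_required_layout; elsewhere the base method's
    # layout is used unchanged.
    if get_sm_version() == 100:
        return DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm()
    return DeepSeekFP8BlockScalesFusedMoEMethod()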