@@ -629,45 +629,8 @@ def create_weights(self, module: torch.nn.Module):
629629
630630 def load_weights (self , module : torch .nn .Module , weights : List [Dict ],
631631 weight_loading_mode : MoEWeightLoadingMode ):
632-
633- if get_sm_version () == 100 :
634- expert_ids = set (module .initial_local_expert_ids )
635- if self .need_load_shared_weights (module ):
636- expert_ids .update (
637- module .layer_load_balancer .get_load_expert_ids ())
638- for name in list (weights .keys ()):
639- if name .endswith ("weight_scale_inv" ):
640- if int (name .split ("." )[0 ]) not in expert_ids :
641- continue
642- weight_name = name .replace ("weight_scale_inv" , "weight" )
643- logger .debug (f"Resmoothing { weight_name } " )
644- weight = weights [weight_name ][:]
645- scale = weights [name ][:]
646- weights [weight_name ], weights [name ] = resmooth_to_fp8_e8m0 (
647- weight , scale )
648632 super ().load_weights (module , weights , weight_loading_mode )
649633
650- if get_sm_version () == 100 :
651- transfromed_w3_w1_scale = transform_sf_into_required_layout (
652- module .quant_scales [0 ],
653- mn = module .w3_w1_weight .shape [1 ],
654- k = module .w3_w1_weight .shape [2 ],
655- recipe = (1 , 128 , 128 ),
656- num_groups = module .w3_w1_weight .shape [0 ],
657- is_sfa = False )
658- module .w3_w1_weight_scaling_factor = nn .Parameter (
659- transfromed_w3_w1_scale , requires_grad = False )
660- transfromed_w2_scale = transform_sf_into_required_layout (
661- module .quant_scales [1 ],
662- mn = module .w2_weight .shape [1 ],
663- k = module .w2_weight .shape [2 ],
664- recipe = (1 , 128 , 128 ),
665- num_groups = module .w3_w1_weight .shape [0 ],
666- is_sfa = False )
667- module .w2_weight_scaling_factor = nn .Parameter (transfromed_w2_scale ,
668- requires_grad = False )
669- self .setup_quant_scales (module )
670-
671634 def setup_quant_scales (self , module : torch .nn .Module ):
672635 module .quant_scales = FusedMoEQuantScalesDeepSeekFP8BlockScales (
673636 fc_weight_scales = module .w3_w1_weight_scaling_factor ,
@@ -765,6 +728,50 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
765728 })
766729
767730
731+ class DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm (
732+ DeepSeekFP8BlockScalesFusedMoEMethod ):
733+
734+ def load_weights (self , module : torch .nn .Module , weights : List [Dict ],
735+ weight_loading_mode : MoEWeightLoadingMode ):
736+ if get_sm_version () == 100 :
737+ expert_ids = set (module .initial_local_expert_ids )
738+ if self .need_load_shared_weights (module ):
739+ expert_ids .update (
740+ module .layer_load_balancer .get_load_expert_ids ())
741+ for name in list (weights .keys ()):
742+ if name .endswith ("weight_scale_inv" ):
743+ if int (name .split ("." )[0 ]) not in expert_ids :
744+ continue
745+ weight_name = name .replace ("weight_scale_inv" , "weight" )
746+ logger .debug (f"Resmoothing { weight_name } " )
747+ weight = weights [weight_name ][:]
748+ scale = weights [name ][:]
749+ weights [weight_name ], weights [name ] = resmooth_to_fp8_e8m0 (
750+ weight , scale )
751+ super ().load_weights (module , weights , weight_loading_mode )
752+
753+ if get_sm_version () == 100 :
754+ transfromed_w3_w1_scale = transform_sf_into_required_layout (
755+ module .quant_scales [0 ],
756+ mn = module .w3_w1_weight .shape [1 ],
757+ k = module .w3_w1_weight .shape [2 ],
758+ recipe = (1 , 128 , 128 ),
759+ num_groups = module .w3_w1_weight .shape [0 ],
760+ is_sfa = False )
761+ module .w3_w1_weight_scaling_factor = nn .Parameter (
762+ transfromed_w3_w1_scale , requires_grad = False )
763+ transfromed_w2_scale = transform_sf_into_required_layout (
764+ module .quant_scales [1 ],
765+ mn = module .w2_weight .shape [1 ],
766+ k = module .w2_weight .shape [2 ],
767+ recipe = (1 , 128 , 128 ),
768+ num_groups = module .w3_w1_weight .shape [0 ],
769+ is_sfa = False )
770+ module .w2_weight_scaling_factor = nn .Parameter (transfromed_w2_scale ,
771+ requires_grad = False )
772+ self .setup_quant_scales (module )
773+
774+
768775class WInt4AFP8FusedMoEMethod (FusedMoEMethodBase ):
769776
770777 def create_weights (self , module : torch .nn .Module ):
0 commit comments