@@ -499,13 +499,34 @@ def batched_fused_marlin_moe(
 
 
 class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
-    def __init__(self, quant_config: FusedMoEQuantConfig):
+    def __init__(
+        self,
+        quant_config: FusedMoEQuantConfig,
+        w13_g_idx: torch.Tensor | None = None,
+        w2_g_idx: torch.Tensor | None = None,
+        w13_g_idx_sort_indices: torch.Tensor | None = None,
+        w2_g_idx_sort_indices: torch.Tensor | None = None,
+        is_k_full: bool = False,
+    ):
         # TODO (varun): Enable activation quantization
         assert quant_config.use_mxfp4_w4a16 or quant_config.use_int4_w4a16, (
             "Supports only mxfp4_w4a16 or int4_w4a16"
         )
+        self.w13_g_idx = w13_g_idx
+        self.w2_g_idx = w2_g_idx
+        self.w13_g_idx_sort_indices = w13_g_idx_sort_indices
+        self.w2_g_idx_sort_indices = w2_g_idx_sort_indices
+        self.is_k_full = is_k_full
         super().__init__(quant_config)
 
+    @property
+    def quant_type_id(self) -> int:
+        return (
+            scalar_types.uint4b8.id
+            if self.quant_config.use_int4_w4a16
+            else scalar_types.float4_e2m1f.id
+        )
+
     def moe_problem_size(
         self,
         a1: torch.Tensor,
@@ -535,8 +556,23 @@ def moe_problem_size(
 
 
 class MarlinExperts(MarlinExpertsBase):
-    def __init__(self, quant_config: FusedMoEQuantConfig):
-        super().__init__(quant_config)
+    def __init__(
+        self,
+        quant_config: FusedMoEQuantConfig,
+        w13_g_idx: torch.Tensor | None = None,
+        w2_g_idx: torch.Tensor | None = None,
+        w13_g_idx_sort_indices: torch.Tensor | None = None,
+        w2_g_idx_sort_indices: torch.Tensor | None = None,
+        is_k_full: bool = False,
+    ):
+        super().__init__(
+            quant_config,
+            w13_g_idx,
+            w2_g_idx,
+            w13_g_idx_sort_indices,
+            w2_g_idx_sort_indices,
+            is_k_full,
+        )
 
     def supports_expert_map(self) -> bool:
         return True
@@ -618,11 +654,7 @@ def apply(
             gating_output=None,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            quant_type_id=(
-                scalar_types.uint4b8.id
-                if self.quant_config.use_int4_w4a16
-                else scalar_types.float4_e2m1f.id
-            ),  # works only for w4a16
+            quant_type_id=self.quant_type_id,
             apply_router_weight_on_input=apply_router_weight_on_input,
             global_num_experts=global_num_experts,
             activation=activation,
@@ -634,6 +666,10 @@ def apply(
             # output buffer allocation. Please refer to workspace_shapes().
             intermediate_cache13=workspace2,
             intermediate_cache2=workspace13,
+            g_idx1=self.w13_g_idx,
+            g_idx2=self.w2_g_idx,
+            sort_indices1=self.w13_g_idx_sort_indices,
+            sort_indices2=self.w2_g_idx_sort_indices,
         )
 
     def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
@@ -656,8 +692,20 @@ def __init__(
         max_num_tokens: int,
         num_dispatchers: int,
         quant_config: FusedMoEQuantConfig,
+        w13_g_idx: torch.Tensor | None = None,
+        w2_g_idx: torch.Tensor | None = None,
+        w13_g_idx_sort_indices: torch.Tensor | None = None,
+        w2_g_idx_sort_indices: torch.Tensor | None = None,
+        is_k_full: bool = False,
     ):
-        super().__init__(quant_config)
+        super().__init__(
+            quant_config,
+            w13_g_idx,
+            w2_g_idx,
+            w13_g_idx_sort_indices,
+            w2_g_idx_sort_indices,
+            is_k_full,
+        )
         self.max_num_tokens = max_num_tokens
         self.num_dispatchers = num_dispatchers
 
@@ -726,16 +774,17 @@ def apply(
             w1_scale=self.w1_scale,
             w2_scale=self.w2_scale,
             gating_output=None,
-            quant_type_id=(
-                scalar_types.uint4b8.id
-                if self.quant_config.use_int4_w4a16
-                else scalar_types.float4_e2m1f.id
-            ),  # works only for w4a16
+            quant_type_id=self.quant_type_id,
             apply_router_weight_on_input=apply_router_weight_on_input,
             global_num_experts=global_num_experts,
             activation=activation,
             expert_map=expert_map,
             output=output,
             intermediate_cache13=workspace13,
             intermediate_cache2=workspace2,
+            g_idx1=self.w13_g_idx,
+            g_idx2=self.w2_g_idx,
+            sort_indices1=self.w13_g_idx_sort_indices,
+            sort_indices2=self.w2_g_idx_sort_indices,
+            is_k_full=self.is_k_full,
         )
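
To see how the new constructor surface composes, here is a minimal usage sketch. It is illustrative only: the import path, the placeholder tensor shapes, and the elided quant-config construction are assumptions, and a real caller would pass g_idx / sort-indices tensors produced by GPTQ act-order weight repacking rather than zeros.

    import torch

    # Assumed import path for illustration; MarlinExperts lives in the file
    # this diff touches.
    from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
        MarlinExperts,
    )

    num_experts, k = 8, 4096  # placeholder sizes

    # quant_config construction is elided here; it must satisfy the assert in
    # MarlinExpertsBase.__init__ (use_mxfp4_w4a16 or use_int4_w4a16).
    experts = MarlinExperts(
        quant_config,
        w13_g_idx=torch.zeros(num_experts, k, dtype=torch.int32),
        w2_g_idx=torch.zeros(num_experts, k, dtype=torch.int32),
        w13_g_idx_sort_indices=torch.zeros(num_experts, k, dtype=torch.int32),
        w2_g_idx_sort_indices=torch.zeros(num_experts, k, dtype=torch.int32),
        is_k_full=True,
    )

    # Every new argument defaults to None / False, so existing call sites
    # such as MarlinExperts(quant_config) keep their old behavior, and
    # quant_type_id is now derived in one place from the quant config.

Because the new parameters default to None / False, the change is backward compatible for existing callers; the act-order metadata is only threaded through to the fused Marlin kernels when a caller supplies it.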