@@ -499,11 +499,35 @@ def batched_fused_marlin_moe(
499499
500500
501501class MarlinExpertsBase (mk .FusedMoEPermuteExpertsUnpermute ):
def __init__(
    self,
    quant_config: FusedMoEQuantConfig,
    w13_g_idx: torch.Tensor | None = None,
    w2_g_idx: torch.Tensor | None = None,
    w13_g_idx_sort_indices: torch.Tensor | None = None,
    w2_g_idx_sort_indices: torch.Tensor | None = None,
    is_k_full: bool = True,
):
    """Hold Marlin act-order metadata and validate the quantization scheme.

    Only weight-only 4-bit schemes are accepted (mxfp4_w4a16 or int4_w4a16);
    the g_idx / sort-index tensors and ``is_k_full`` are stored for later use
    by the fused Marlin kernels.
    """
    # TODO(varun): enable activation quantization.
    scheme_ok = quant_config.use_mxfp4_w4a16 or quant_config.use_int4_w4a16
    assert scheme_ok, "Supports only mxfp4_w4a16 or int4_w4a16"
    # Act-order group indices and their sort permutations (None when the
    # checkpoint has no act-order).
    self.w13_g_idx, self.w2_g_idx = w13_g_idx, w2_g_idx
    self.w13_g_idx_sort_indices = w13_g_idx_sort_indices
    self.w2_g_idx_sort_indices = w2_g_idx_sort_indices
    self.is_k_full = is_k_full
    super().__init__(quant_config)
506521
@property
def quant_type_id(self) -> int:
    """Scalar type id of the packed weights.

    uint4b8 is used for int4_w4a16 checkpoints; float4_e2m1f is used for
    mxfp4 checkpoints.
    """
    if self.quant_config.use_int4_w4a16:
        return scalar_types.uint4b8.id
    return scalar_types.float4_e2m1f.id
507531 def moe_problem_size (
508532 self ,
509533 a1 : torch .Tensor ,
@@ -533,8 +557,23 @@ def moe_problem_size(
533557
534558
535559class MarlinExperts (MarlinExpertsBase ):
def __init__(
    self,
    quant_config: FusedMoEQuantConfig,
    w13_g_idx: torch.Tensor | None = None,
    w2_g_idx: torch.Tensor | None = None,
    w13_g_idx_sort_indices: torch.Tensor | None = None,
    w2_g_idx_sort_indices: torch.Tensor | None = None,
    is_k_full: bool = True,
):
    """Forward all construction state to MarlinExpertsBase unchanged."""
    super().__init__(
        quant_config,
        w13_g_idx=w13_g_idx,
        w2_g_idx=w2_g_idx,
        w13_g_idx_sort_indices=w13_g_idx_sort_indices,
        w2_g_idx_sort_indices=w2_g_idx_sort_indices,
        is_k_full=is_k_full,
    )
538577
def supports_expert_map(self) -> bool:
    """This kernel supports expert-map based expert parallelism."""
    return True
@@ -616,7 +655,7 @@ def apply(
616655 gating_output = None ,
617656 topk_weights = topk_weights ,
618657 topk_ids = topk_ids ,
619- quant_type_id = scalar_types . float4_e2m1f . id , # works only for w4a16
658+ quant_type_id = self . quant_type_id ,
620659 apply_router_weight_on_input = apply_router_weight_on_input ,
621660 global_num_experts = global_num_experts ,
622661 activation = activation ,
@@ -628,6 +667,11 @@ def apply(
628667 # output buffer allocation. Please refer to workspace_shapes().
629668 intermediate_cache13 = workspace2 ,
630669 intermediate_cache2 = workspace13 ,
670+ g_idx1 = self .w13_g_idx ,
671+ g_idx2 = self .w2_g_idx ,
672+ sort_indices1 = self .w13_g_idx_sort_indices ,
673+ sort_indices2 = self .w2_g_idx_sort_indices ,
674+ is_k_full = self .is_k_full ,
631675 )
632676
633677 def moe_sum (self , input : torch .Tensor , output : torch .Tensor ) -> None :
@@ -650,8 +694,20 @@ def __init__(
650694 max_num_tokens : int ,
651695 num_dispatchers : int ,
652696 quant_config : FusedMoEQuantConfig ,
697+ w13_g_idx : torch .Tensor | None = None ,
698+ w2_g_idx : torch .Tensor | None = None ,
699+ w13_g_idx_sort_indices : torch .Tensor | None = None ,
700+ w2_g_idx_sort_indices : torch .Tensor | None = None ,
701+ is_k_full : bool = True ,
653702 ):
654- super ().__init__ (quant_config )
703+ super ().__init__ (
704+ quant_config ,
705+ w13_g_idx ,
706+ w2_g_idx ,
707+ w13_g_idx_sort_indices ,
708+ w2_g_idx_sort_indices ,
709+ is_k_full ,
710+ )
655711 self .max_num_tokens = max_num_tokens
656712 self .num_dispatchers = num_dispatchers
657713
@@ -720,12 +776,17 @@ def apply(
720776 w1_scale = self .w1_scale ,
721777 w2_scale = self .w2_scale ,
722778 gating_output = None ,
723- quant_type_id = scalar_types . float4_e2m1f . id , # works only for w4a16
779+ quant_type_id = self . quant_type_id ,
724780 apply_router_weight_on_input = apply_router_weight_on_input ,
725781 global_num_experts = global_num_experts ,
726782 activation = activation ,
727783 expert_map = expert_map ,
728784 output = output ,
729785 intermediate_cache13 = workspace13 ,
730786 intermediate_cache2 = workspace2 ,
787+ g_idx1 = self .w13_g_idx ,
788+ g_idx2 = self .w2_g_idx ,
789+ sort_indices1 = self .w13_g_idx_sort_indices ,
790+ sort_indices2 = self .w2_g_idx_sort_indices ,
791+ is_k_full = self .is_k_full ,
731792 )
0 commit comments