
Commit e9d369a

pass correct args

Signed-off-by: Lu Fang <[email protected]>
1 parent 60c880b

File tree: 2 files changed, +77 −15 lines


vllm/model_executor/layers/fused_moe/fused_marlin_moe.py

Lines changed: 63 additions & 14 deletions
@@ -499,13 +499,34 @@ def batched_fused_marlin_moe(
 
 
 class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
-    def __init__(self, quant_config: FusedMoEQuantConfig):
+    def __init__(
+        self,
+        quant_config: FusedMoEQuantConfig,
+        w13_g_idx: torch.Tensor | None = None,
+        w2_g_idx: torch.Tensor | None = None,
+        w13_g_idx_sort_indices: torch.Tensor | None = None,
+        w2_g_idx_sort_indices: torch.Tensor | None = None,
+        is_k_full: bool = False,
+    ):
         # TODO (varun) : Enable activation quantization
         assert quant_config.use_mxfp4_w4a16 or quant_config.use_int4_w4a16, (
             "Supports only mxfp4_w4a16 or int4_w4a16"
         )
+        self.w13_g_idx = w13_g_idx
+        self.w2_g_idx = w2_g_idx
+        self.w13_g_idx_sort_indices = w13_g_idx_sort_indices
+        self.w2_g_idx_sort_indices = w2_g_idx_sort_indices
+        self.is_k_full = is_k_full
         super().__init__(quant_config)
 
+    @property
+    def quant_type_id(self) -> int:
+        return (
+            scalar_types.uint4b8.id
+            if self.quant_config.use_int4_w4a16
+            else scalar_types.float4_e2m1f.id
+        )
+
     def moe_problem_size(
         self,
         a1: torch.Tensor,
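
The new quant_type_id property centralizes a scalar-type choice that was previously written out inline at two call sites (see the later hunks that drop those inline conditionals). A minimal, self-contained sketch of the same dispatch, with placeholder ids standing in for the real scalar_types registry:

    # Minimal sketch, not vLLM code: placeholder ids stand in for
    # scalar_types.uint4b8.id and scalar_types.float4_e2m1f.id.
    from dataclasses import dataclass

    UINT4B8_ID = 0        # stand-in for scalar_types.uint4b8.id
    FLOAT4_E2M1F_ID = 1   # stand-in for scalar_types.float4_e2m1f.id

    @dataclass
    class StubQuantConfig:  # hypothetical stand-in for FusedMoEQuantConfig
        use_int4_w4a16: bool = False
        use_mxfp4_w4a16: bool = False

    def quant_type_id(cfg: StubQuantConfig) -> int:
        # int4-w4a16 weights use uint4b8; otherwise mxfp4 (fp4 e2m1).
        return UINT4B8_ID if cfg.use_int4_w4a16 else FLOAT4_E2M1F_ID

    assert quant_type_id(StubQuantConfig(use_int4_w4a16=True)) == UINT4B8_ID
    assert quant_type_id(StubQuantConfig(use_mxfp4_w4a16=True)) == FLOAT4_E2M1F_ID
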
@@ -535,8 +556,23 @@ def moe_problem_size(
 
 
 class MarlinExperts(MarlinExpertsBase):
-    def __init__(self, quant_config: FusedMoEQuantConfig):
-        super().__init__(quant_config)
+    def __init__(
+        self,
+        quant_config: FusedMoEQuantConfig,
+        w13_g_idx: torch.Tensor | None = None,
+        w2_g_idx: torch.Tensor | None = None,
+        w13_g_idx_sort_indices: torch.Tensor | None = None,
+        w2_g_idx_sort_indices: torch.Tensor | None = None,
+        is_k_full: bool = False,
+    ):
+        super().__init__(
+            quant_config,
+            w13_g_idx,
+            w2_g_idx,
+            w13_g_idx_sort_indices,
+            w2_g_idx_sort_indices,
+            is_k_full,
+        )
 
     def supports_expert_map(self) -> bool:
         return True
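
For context on the plumbed-through arguments: in GPTQ-style act-order ("desc_act") checkpoints, g_idx maps each input channel to its quantization group, and Marlin consumes a permutation that sorts channels by group. A sketch, under the assumption (consistent with vLLM's Marlin utilities) that the sort indices are simply an argsort of g_idx:

    import torch

    # Illustrative shapes only: 4 experts, 256 input channels, groups of 64.
    num_experts, hidden_size, group_size = 4, 256, 64
    g_idx = torch.randint(0, hidden_size // group_size,
                          (num_experts, hidden_size), dtype=torch.int32)

    # Assumed derivation: sort indices are the per-expert argsort of g_idx.
    g_idx_sort_indices = torch.argsort(g_idx, dim=-1).to(torch.int32)

    # Gathering by the sort order leaves each expert's channels grouped
    # contiguously by quantization group (non-decreasing group ids).
    sorted_g_idx = torch.gather(g_idx, 1, g_idx_sort_indices.long())
    assert (sorted_g_idx.diff(dim=-1) >= 0).all()
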
@@ -618,11 +654,7 @@ def apply(
             gating_output=None,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            quant_type_id=(
-                scalar_types.uint4b8.id
-                if self.quant_config.use_int4_w4a16
-                else scalar_types.float4_e2m1f.id
-            ),  # works only for w4a16
+            quant_type_id=self.quant_type_id,
             apply_router_weight_on_input=apply_router_weight_on_input,
             global_num_experts=global_num_experts,
             activation=activation,
@@ -634,6 +666,10 @@ def apply(
             # output buffer allocation. Please refer to workspace_shapes().
             intermediate_cache13=workspace2,
             intermediate_cache2=workspace13,
+            g_idx1=self.w13_g_idx,
+            g_idx2=self.w2_g_idx,
+            sort_indices1=self.w13_g_idx_sort_indices,
+            sort_indices2=self.w2_g_idx_sort_indices,
         )
 
     def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
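
Why forwarding g_idx*/sort_indices* matters: with act-order, the K (reduction) dimension must be permuted consistently on both GEMM operands, and omitting the arguments means the kernel falls back to unpermuted defaults. A dense-math sketch of the invariant (plain PyTorch, not the Marlin kernel):

    import torch

    K, N = 256, 128
    a = torch.randn(8, K)        # activations
    w = torch.randn(K, N)        # stand-in for already-dequantized weights
    perm = torch.randperm(K)     # stand-in for g_idx_sort_indices

    # Applying the same permutation to the activation's K dim and the
    # weight's rows leaves the GEMM result unchanged; applying it to
    # only one side (the failure mode when indices aren't passed) does not.
    ref = a @ w
    out = a[:, perm] @ w[perm, :]
    torch.testing.assert_close(out, ref)
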
@@ -656,8 +692,20 @@ def __init__(
         max_num_tokens: int,
         num_dispatchers: int,
         quant_config: FusedMoEQuantConfig,
+        w13_g_idx: torch.Tensor | None = None,
+        w2_g_idx: torch.Tensor | None = None,
+        w13_g_idx_sort_indices: torch.Tensor | None = None,
+        w2_g_idx_sort_indices: torch.Tensor | None = None,
+        is_k_full: bool = False,
     ):
-        super().__init__(quant_config)
+        super().__init__(
+            quant_config,
+            w13_g_idx,
+            w2_g_idx,
+            w13_g_idx_sort_indices,
+            w2_g_idx_sort_indices,
+            is_k_full,
+        )
         self.max_num_tokens = max_num_tokens
         self.num_dispatchers = num_dispatchers
 
@@ -726,16 +774,17 @@ def apply(
             w1_scale=self.w1_scale,
             w2_scale=self.w2_scale,
             gating_output=None,
-            quant_type_id=(
-                scalar_types.uint4b8.id
-                if self.quant_config.use_int4_w4a16
-                else scalar_types.float4_e2m1f.id
-            ),  # works only for w4a16
+            quant_type_id=self.quant_type_id,
             apply_router_weight_on_input=apply_router_weight_on_input,
             global_num_experts=global_num_experts,
             activation=activation,
             expert_map=expert_map,
             output=output,
             intermediate_cache13=workspace13,
             intermediate_cache2=workspace2,
+            g_idx1=self.w13_g_idx,
+            g_idx2=self.w2_g_idx,
+            sort_indices1=self.w13_g_idx_sort_indices,
+            sort_indices2=self.w2_g_idx_sort_indices,
+            is_k_full=self.is_k_full,
         )
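
is_k_full is also now forwarded in the batched path. A sketch of its usual meaning (an assumption here, mirroring vLLM's GPTQ-Marlin utilities rather than anything shown in this diff): K only counts as "full" when act-order is disabled or the rank holds the entire reduction dimension.

    def marlin_is_k_full_sketch(act_order: bool, is_row_parallel: bool) -> bool:
        # With act_order, channels are reordered across group boundaries,
        # so a row-parallel shard of K cannot be treated as complete.
        return (not act_order) or (not is_row_parallel)

    assert marlin_is_k_full_sketch(act_order=False, is_row_parallel=True)
    assert not marlin_is_k_full_sketch(act_order=True, is_row_parallel=True)
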

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 14 additions & 1 deletion
@@ -1581,6 +1581,7 @@ def select_gemm_impl(
         prepare_finalize: mk.FusedMoEPrepareAndFinalize,
         layer: torch.nn.Module,
     ) -> mk.FusedMoEPermuteExpertsUnpermute:
+        assert self.num_bits == 4, "only supporting w4"
         layer.w13_weight = layer.w13_weight_packed
         layer.w2_weight = layer.w2_weight_packed
         assert all([w is not None for w in [layer.w13_weight, layer.w2_weight]])
@@ -1593,9 +1594,21 @@ def select_gemm_impl(
                 max_num_tokens=prepare_finalize.max_num_tokens_per_rank(),
                 num_dispatchers=prepare_finalize.num_dispatchers(),
                 quant_config=self.moe_quant_config,
+                w13_g_idx=layer.w13_weight_g_idx,
+                w2_g_idx=layer.w2_g_idx,
+                w13_g_idx_sort_indices=layer.w13_g_idx_sort_indices,
+                w2_g_idx_sort_indices=layer.w2_g_idx_sort_indices,
+                is_k_full=self.is_k_full,
             )
         else:
-            return MarlinExperts(self.moe_quant_config)
+            return MarlinExperts(
+                quant_config=self.moe_quant_config,
+                w13_g_idx=layer.w13_weight_g_idx,
+                w2_g_idx=layer.w2_g_idx,
+                w13_g_idx_sort_indices=layer.w13_g_idx_sort_indices,
+                w2_g_idx_sort_indices=layer.w2_g_idx_sort_indices,
+                is_k_full=self.is_k_full,
+            )
 
     def apply(
         self,
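
The fix named in the commit message lands in this hunk: both branches of select_gemm_impl now forward the same act-order metadata from the layer, matching the widened constructors above. A self-contained sketch of the dispatch pattern (all names are stand-ins, not vLLM classes):

    # Stand-in classes; the point is that both branches must forward the
    # same keyword arguments, which is easy to miss on one branch.
    class Experts:
        def __init__(self, **kwargs):
            self.kwargs = kwargs

    class BatchedExperts(Experts):
        pass

    def select_gemm_impl_sketch(batched: bool, **act_order_kwargs) -> Experts:
        if batched:
            return BatchedExperts(max_num_tokens=64, **act_order_kwargs)
        return Experts(**act_order_kwargs)

    impl = select_gemm_impl_sketch(False, w13_g_idx=None, is_k_full=True)
    assert impl.kwargs["is_k_full"] is True
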
