Commit 7e082bc

Support DeepEP for Kimi-K2-Thinking by enabling GEMM selection for compressed-tensors Marlin WNA16 (vllm-project#28574)
Signed-off-by: Lu Fang <[email protected]>
1 parent dbbe0c7 commit 7e082bc

File tree

2 files changed: +118 -9 lines

vllm/model_executor/layers/fused_moe/fused_marlin_moe.py

Lines changed: 68 additions & 7 deletions
@@ -499,11 +499,35 @@ def batched_fused_marlin_moe(
 
 
 class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
-    def __init__(self, quant_config: FusedMoEQuantConfig):
+    def __init__(
+        self,
+        quant_config: FusedMoEQuantConfig,
+        w13_g_idx: torch.Tensor | None = None,
+        w2_g_idx: torch.Tensor | None = None,
+        w13_g_idx_sort_indices: torch.Tensor | None = None,
+        w2_g_idx_sort_indices: torch.Tensor | None = None,
+        is_k_full: bool = True,
+    ):
         # TODO (varun) : Enable activation quantization
-        assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
+        assert quant_config.use_mxfp4_w4a16 or quant_config.use_int4_w4a16, (
+            "Supports only mxfp4_w4a16 or int4_w4a16"
+        )
+        self.w13_g_idx = w13_g_idx
+        self.w2_g_idx = w2_g_idx
+        self.w13_g_idx_sort_indices = w13_g_idx_sort_indices
+        self.w2_g_idx_sort_indices = w2_g_idx_sort_indices
+        self.is_k_full = is_k_full
         super().__init__(quant_config)
 
+    @property
+    def quant_type_id(self) -> int:
+        # uint4b8 will be set for int4 weight and float4_e2m1f will be used for mxfp4
+        return (
+            scalar_types.uint4b8.id
+            if self.quant_config.use_int4_w4a16
+            else scalar_types.float4_e2m1f.id
+        )
+
     def moe_problem_size(
         self,
         a1: torch.Tensor,
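
Editor's note: the new quant_type_id property is what lets the same Marlin experts serve both weight formats, with int4 w4a16 resolving to the uint4b8 scalar type and mxfp4 keeping the float4_e2m1f type that was previously hard-coded at the call sites. Below is a minimal standalone sketch of that dispatch; the placeholder class is hypothetical and stands in for the real FusedMoEQuantConfig and scalar_types, so the sketch has no vLLM dependency.

# Minimal sketch of the quant_type_id dispatch; FakeQuantConfig is a hypothetical
# stand-in, and string names replace the scalar_types ids used in vLLM.
from dataclasses import dataclass

@dataclass
class FakeQuantConfig:
    use_int4_w4a16: bool = False
    use_mxfp4_w4a16: bool = False

def quant_type_name(cfg: FakeQuantConfig) -> str:
    # Mirrors MarlinExpertsBase.quant_type_id: int4 w4a16 -> uint4b8, mxfp4 -> float4_e2m1f.
    assert cfg.use_int4_w4a16 or cfg.use_mxfp4_w4a16
    return "uint4b8" if cfg.use_int4_w4a16 else "float4_e2m1f"

print(quant_type_name(FakeQuantConfig(use_int4_w4a16=True)))   # uint4b8
print(quant_type_name(FakeQuantConfig(use_mxfp4_w4a16=True)))  # float4_e2m1f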
@@ -533,8 +557,23 @@ def moe_problem_size(
 
 
 class MarlinExperts(MarlinExpertsBase):
-    def __init__(self, quant_config: FusedMoEQuantConfig):
-        super().__init__(quant_config)
+    def __init__(
+        self,
+        quant_config: FusedMoEQuantConfig,
+        w13_g_idx: torch.Tensor | None = None,
+        w2_g_idx: torch.Tensor | None = None,
+        w13_g_idx_sort_indices: torch.Tensor | None = None,
+        w2_g_idx_sort_indices: torch.Tensor | None = None,
+        is_k_full: bool = True,
+    ):
+        super().__init__(
+            quant_config,
+            w13_g_idx,
+            w2_g_idx,
+            w13_g_idx_sort_indices,
+            w2_g_idx_sort_indices,
+            is_k_full,
+        )
 
     def supports_expert_map(self) -> bool:
         return True
@@ -616,7 +655,7 @@ def apply(
             gating_output=None,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            quant_type_id=scalar_types.float4_e2m1f.id,  # works only for w4a16
+            quant_type_id=self.quant_type_id,
             apply_router_weight_on_input=apply_router_weight_on_input,
             global_num_experts=global_num_experts,
             activation=activation,
@@ -628,6 +667,11 @@ def apply(
             # output buffer allocation. Please refer to workspace_shapes().
             intermediate_cache13=workspace2,
             intermediate_cache2=workspace13,
+            g_idx1=self.w13_g_idx,
+            g_idx2=self.w2_g_idx,
+            sort_indices1=self.w13_g_idx_sort_indices,
+            sort_indices2=self.w2_g_idx_sort_indices,
+            is_k_full=self.is_k_full,
         )
 
     def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
@@ -650,8 +694,20 @@ def __init__(
         max_num_tokens: int,
         num_dispatchers: int,
         quant_config: FusedMoEQuantConfig,
+        w13_g_idx: torch.Tensor | None = None,
+        w2_g_idx: torch.Tensor | None = None,
+        w13_g_idx_sort_indices: torch.Tensor | None = None,
+        w2_g_idx_sort_indices: torch.Tensor | None = None,
+        is_k_full: bool = True,
     ):
-        super().__init__(quant_config)
+        super().__init__(
+            quant_config,
+            w13_g_idx,
+            w2_g_idx,
+            w13_g_idx_sort_indices,
+            w2_g_idx_sort_indices,
+            is_k_full,
+        )
         self.max_num_tokens = max_num_tokens
         self.num_dispatchers = num_dispatchers
 
@@ -720,12 +776,17 @@ def apply(
             w1_scale=self.w1_scale,
             w2_scale=self.w2_scale,
             gating_output=None,
-            quant_type_id=scalar_types.float4_e2m1f.id,  # works only for w4a16
+            quant_type_id=self.quant_type_id,
             apply_router_weight_on_input=apply_router_weight_on_input,
             global_num_experts=global_num_experts,
             activation=activation,
             expert_map=expert_map,
             output=output,
             intermediate_cache13=workspace13,
             intermediate_cache2=workspace2,
+            g_idx1=self.w13_g_idx,
+            g_idx2=self.w2_g_idx,
+            sort_indices1=self.w13_g_idx_sort_indices,
+            sort_indices2=self.w2_g_idx_sort_indices,
+            is_k_full=self.is_k_full,
         )
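
Editor's note: the optional w13_g_idx/w2_g_idx tensors and their *_sort_indices companions carry GPTQ-style activation-order metadata. g_idx maps each input channel to the scale group it was quantized with, sort_indices is the permutation that regroups channels into the layout Marlin-style kernels expect, and is_k_full signals whether the full K dimension is present on this rank. The sketch below illustrates that relationship under the assumption that the tensors follow the usual GPTQ act-order convention; all shapes and values are made up for the example.

# Illustrative sketch (not vLLM code) of what g_idx / sort_indices encode for
# activation-ordered group quantization. Shapes and values are hypothetical.
import torch

K, N, group_size = 8, 4, 4
num_groups = K // group_size

# g_idx maps each input channel k to its quantization scale group; with
# act_order the mapping need not be contiguous.
g_idx = torch.tensor([0, 1, 0, 1, 1, 0, 1, 0])

# sort_indices is the permutation that groups channels by g_idx.
sort_indices = torch.argsort(g_idx)

q_weight = torch.randint(-8, 8, (K, N))  # fake unpacked int4 weight
scales = torch.rand(num_groups, N)       # per-group, per-column scales

# Dequantize by looking up each row's scale group through g_idx.
dequant = q_weight.float() * scales[g_idx]

# Equivalent view after permuting rows into group order (what the kernel sees).
dequant_sorted = q_weight[sort_indices].float() * scales[g_idx[sort_indices]]
assert torch.allclose(dequant[sort_indices], dequant_sorted)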

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 50 additions & 2 deletions
@@ -35,7 +35,11 @@
 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
     is_valid_flashinfer_cutlass_fused_moe,
 )
-from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+    BatchedMarlinExperts,
+    MarlinExperts,
+    fused_marlin_moe,
+)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (  # noqa
     WNA16_SUPPORTED_BITS,
     WNA16_SUPPORTED_TYPES_MAP,
@@ -1578,7 +1582,51 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
-        return None
+        if self.num_bits != 4:
+            return None
+        return int4_w4a16_moe_quant_config(
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            w1_zp=None,
+            w2_zp=None,
+            block_shape=[0, self.group_size],
+        )
+
+    def select_gemm_impl(
+        self,
+        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
+        layer: torch.nn.Module,
+    ) -> mk.FusedMoEPermuteExpertsUnpermute:
+        assert self.num_bits == 4, "only supporting w4"
+        layer.w13_weight = layer.w13_weight_packed
+        layer.w2_weight = layer.w2_weight_packed
+        assert all([w is not None for w in [layer.w13_weight, layer.w2_weight]])
+        assert self.moe_quant_config is not None
+        if (
+            prepare_finalize.activation_format
+            == mk.FusedMoEActivationFormat.BatchedExperts
+        ):
+            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
+            assert max_num_tokens_per_rank is not None
+            return BatchedMarlinExperts(
+                max_num_tokens=max_num_tokens_per_rank,
+                num_dispatchers=prepare_finalize.num_dispatchers(),
+                quant_config=self.moe_quant_config,
+                w13_g_idx=layer.w13_weight_g_idx,
+                w2_g_idx=layer.w2_weight_g_idx,
+                w13_g_idx_sort_indices=layer.w13_g_idx_sort_indices,
+                w2_g_idx_sort_indices=layer.w2_g_idx_sort_indices,
+                is_k_full=self.is_k_full,
+            )
+        else:
+            return MarlinExperts(
+                quant_config=self.moe_quant_config,
+                w13_g_idx=layer.w13_weight_g_idx,
+                w2_g_idx=layer.w2_weight_g_idx,
+                w13_g_idx_sort_indices=layer.w13_g_idx_sort_indices,
+                w2_g_idx_sort_indices=layer.w2_g_idx_sort_indices,
+                is_k_full=self.is_k_full,
+            )
 
     def apply(
         self,
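
Editor's note: with select_gemm_impl in place, an all-to-all prepare/finalize backend such as DeepEP can hand its batched activation format to BatchedMarlinExperts, which is what unblocks expert-parallel serving of compressed-tensors w4a16 models like Kimi-K2-Thinking. The sketch below shows how this might be exercised end to end; the environment variable, engine arguments, and model id are assumptions about current vLLM and Hugging Face naming, not something this commit defines, so verify them against your version.

# Hedged sketch: serve a compressed-tensors w4a16 MoE checkpoint with a DeepEP
# all-to-all backend so the select_gemm_impl() path above returns Marlin experts.
# VLLM_ALL2ALL_BACKEND, enable_expert_parallel, and the model id are assumptions;
# check them against your vLLM release.
import os

os.environ["VLLM_ALL2ALL_BACKEND"] = "deepep_high_throughput"  # assumed env var

from vllm import LLM

llm = LLM(
    model="moonshotai/Kimi-K2-Thinking",  # hypothetical model id
    tensor_parallel_size=8,
    enable_expert_parallel=True,          # assumed engine argument
)
print(llm.generate("Hello, world")[0].outputs[0].text)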
