
Commit bdc02ca

enable modular experts for compressed tensor marlin wn16 moe
Signed-off-by: Lu Fang <[email protected]>
1 parent 35d801f

File tree

2 files changed, +54 -6 lines changed


vllm/model_executor/layers/fused_moe/fused_marlin_moe.py

Lines changed: 13 additions & 4 deletions
@@ -501,7 +501,9 @@ def batched_fused_marlin_moe(
 class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
     def __init__(self, quant_config: FusedMoEQuantConfig):
         # TODO (varun) : Enable activation quantization
-        assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
+        assert quant_config.use_mxfp4_w4a16 or quant_config.use_int4_w4a16, (
+            "Supports only mxfp4_w4a16 or int4_w4a16"
+        )
         super().__init__(quant_config)
 
     def moe_problem_size(
@@ -616,7 +618,11 @@ def apply(
             gating_output=None,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            quant_type_id=scalar_types.float4_e2m1f.id,  # works only for w4a16
+            quant_type_id=(
+                scalar_types.uint4b8.id
+                if self.quant_config.use_int4_w4a16
+                else scalar_types.float4_e2m1f.id
+            ),  # works only for w4a16
             apply_router_weight_on_input=apply_router_weight_on_input,
             global_num_experts=global_num_experts,
             activation=activation,
@@ -720,8 +726,11 @@ def apply(
             w1_scale=self.w1_scale,
             w2_scale=self.w2_scale,
             gating_output=None,
-            quant_type_id=scalar_types.float4_e2m1f.id,  # works only for w4a16
-            apply_router_weight_on_input=apply_router_weight_on_input,
+            quant_type_id=(
+                scalar_types.uint4b8.id
+                if self.quant_config.use_int4_w4a16
+                else scalar_types.float4_e2m1f.id
+            ),  # works only for w4a16
             global_num_experts=global_num_experts,
             activation=activation,
             expert_map=expert_map,
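
Both apply() call sites now derive the Marlin quant type from the quant config instead of hard-coding scalar_types.float4_e2m1f: int4_w4a16 maps to the bias-8 unsigned 4-bit type (uint4b8), while mxfp4_w4a16 keeps the fp4 type. A minimal sketch of that selection, with stand-in ids in place of vLLM's scalar_types (QuantFlags, UINT4B8_ID, FLOAT4_E2M1F_ID, and marlin_quant_type_id are illustrative names, not vLLM APIs):

# Sketch only: mirrors the quant-type branch added above.
from dataclasses import dataclass


@dataclass
class QuantFlags:
    # Stand-in for the two FusedMoEQuantConfig flags the assert allows.
    use_mxfp4_w4a16: bool = False
    use_int4_w4a16: bool = False


UINT4B8_ID = 0        # stand-in for scalar_types.uint4b8.id
FLOAT4_E2M1F_ID = 1   # stand-in for scalar_types.float4_e2m1f.id


def marlin_quant_type_id(cfg: QuantFlags) -> int:
    # Same shape as the new assert in MarlinExpertsBase.__init__.
    assert cfg.use_mxfp4_w4a16 or cfg.use_int4_w4a16, (
        "Supports only mxfp4_w4a16 or int4_w4a16"
    )
    # int4_w4a16 -> uint4b8; otherwise the mxfp4 fp4 type.
    return UINT4B8_ID if cfg.use_int4_w4a16 else FLOAT4_E2M1F_ID


assert marlin_quant_type_id(QuantFlags(use_int4_w4a16=True)) == UINT4B8_ID
assert marlin_quant_type_id(QuantFlags(use_mxfp4_w4a16=True)) == FLOAT4_E2M1F_ID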

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 41 additions & 2 deletions
@@ -35,7 +35,11 @@
 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
     is_valid_flashinfer_cutlass_fused_moe,
 )
-from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+    BatchedMarlinExperts,
+    MarlinExperts,
+    fused_marlin_moe,
+)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (  # noqa
     WNA16_SUPPORTED_BITS,
     WNA16_SUPPORTED_TYPES_MAP,
@@ -1562,7 +1566,42 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
-        return None
+        assert self.num_bits == 4 or self.num_bits == 8
+        config_builder = (
+            int4_w4a16_moe_quant_config
+            if self.num_bits == 4
+            else int8_w8a16_moe_quant_config
+        )
+
+        return config_builder(
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            w1_zp=None,
+            w2_zp=None,
+            block_shape=[0, self.group_size],
+        )
+
+    def select_gemm_impl(
+        self,
+        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
+        layer: torch.nn.Module,
+    ) -> mk.FusedMoEPermuteExpertsUnpermute:
+
+        layer.w13_weight = layer.w13_weight_packed
+        layer.w2_weight = layer.w2_weight_packed
+        assert all([w is not None for w in [layer.w13_weight, layer.w2_weight]])
+        assert self.moe_quant_config is not None
+        if (
+            prepare_finalize.activation_format
+            == mk.FusedMoEActivationFormat.BatchedExperts
+        ):
+            return BatchedMarlinExperts(
+                max_num_tokens=prepare_finalize.max_num_tokens_per_rank(),
+                num_dispatchers=prepare_finalize.num_dispatchers(),
+                quant_config=self.moe_quant_config,
+            )
+        else:
+            return MarlinExperts(self.moe_quant_config)
 
     def apply(
         self,
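
get_fused_moe_quant_config previously opted out of the modular path by returning None; it now builds a real w4a16/w8a16 config with per-group weight scales and no zero points. A rough sketch of what block_shape=[0, self.group_size] implies for the weights, using an illustrative dequant helper (dequant_w4a16 is not a vLLM API; it assumes symmetric int4 stored as int8, matching w1_zp=None / w2_zp=None above):

# Illustration only: a group-quantized w4a16 weight in this scheme.
# block_shape=[0, group_size] groups along the input (k) dimension only,
# so each output row carries k // group_size fp16 scales.
import torch


def dequant_w4a16(q: torch.Tensor, scale: torch.Tensor, group_size: int) -> torch.Tensor:
    # q: [n, k] int4 values stored in int8 (symmetric range [-8, 7]).
    # scale: [n, k // group_size] per-group fp16 scales.
    s = scale.repeat_interleave(group_size, dim=1)  # broadcast scales to [n, k]
    return q.to(torch.float16) * s


q = torch.randint(-8, 8, (16, 128), dtype=torch.int8)
scale = torch.rand(16, 128 // 64, dtype=torch.float16)  # group_size = 64
w = dequant_w4a16(q, scale, group_size=64)
assert w.shape == (16, 128)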

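select_gemm_impl is the hook that makes these experts "modular": the prepare/finalize object advertises its activation format, and the matching Marlin experts implementation is returned. A simplified sketch of that dispatch under stand-in types (ActivationFormat and the two small classes below are illustrative, not the actual mk interfaces):

# Simplified stand-ins for the modular-kernel (mk) dispatch shown above.
from enum import Enum, auto


class ActivationFormat(Enum):
    Standard = auto()
    BatchedExperts = auto()


class Experts:
    def __init__(self, quant_config):
        self.quant_config = quant_config


class BatchedExperts(Experts):
    # The batched all-to-all path needs the per-rank token budget and the
    # dispatcher count up front, as in BatchedMarlinExperts above.
    def __init__(self, max_num_tokens: int, num_dispatchers: int, quant_config):
        super().__init__(quant_config)
        self.max_num_tokens = max_num_tokens
        self.num_dispatchers = num_dispatchers


def select_experts(fmt: ActivationFormat, quant_config,
                   max_tokens_per_rank: int, num_dispatchers: int) -> Experts:
    if fmt == ActivationFormat.BatchedExperts:
        return BatchedExperts(max_tokens_per_rank, num_dispatchers, quant_config)
    return Experts(quant_config)


impl = select_experts(ActivationFormat.BatchedExperts, quant_config={},
                      max_tokens_per_rank=256, num_dispatchers=2)
assert isinstance(impl, BatchedExperts)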