Commit e4fff80

enable modular experts for compressed tensor marlin wn16 moe
Signed-off-by: Lu Fang <[email protected]>
1 parent 35d801f commit e4fff80

File tree

3 files changed: 69 additions & 7 deletions

vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
vllm/model_executor/layers/fused_moe/layer.py
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

vllm/model_executor/layers/fused_moe/fused_marlin_moe.py

Lines changed: 8 additions & 3 deletions
@@ -501,7 +501,9 @@ def batched_fused_marlin_moe(
 class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
     def __init__(self, quant_config: FusedMoEQuantConfig):
         # TODO (varun) : Enable activation quantization
-        assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
+        assert quant_config.use_mxfp4_w4a16 or quant_config.use_int4_w4a16, (
+            "Supports only mxfp4_w4a16 or int4_w4a16"
+        )
         super().__init__(quant_config)
 
     def moe_problem_size(

@@ -720,8 +722,11 @@ def apply(
             w1_scale=self.w1_scale,
             w2_scale=self.w2_scale,
             gating_output=None,
-            quant_type_id=scalar_types.float4_e2m1f.id,  # works only for w4a16
-            apply_router_weight_on_input=apply_router_weight_on_input,
+            quant_type_id=(
+                scalar_types.uint4b8.id
+                if self.quant_config.use_int4_w4a16
+                else scalar_types.float4_e2m1f.id
+            ),  # works only for w4a16 or mxfp4_w4a16
             global_num_experts=global_num_experts,
             activation=activation,
             expert_map=expert_map,

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 12 additions & 2 deletions
@@ -432,11 +432,21 @@ def apply(
             zero_expert_num=zero_expert_num,
             zero_expert_type=zero_expert_type,
         )
+        w1 = (
+            layer.w13_weight_packed
+            if hasattr(layer, "w13_weight_packed")
+            else layer.w13_weight
+        )
+        w2 = (
+            layer.w2_weight_packed
+            if hasattr(layer, "w2_weight_packed")
+            else layer.w2_weight
+        )
 
         result = self.fused_experts(
             hidden_states=x,
-            w1=layer.w13_weight,
-            w2=layer.w2_weight,
+            w1=w1,
+            w2=w2,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             inplace=self.allow_inplace,
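
The hasattr fallback is needed because the compressed-tensors WNA16 MoE method registers its packed integer weights as w13_weight_packed / w2_weight_packed, while other quantization methods keep the plain w13_weight / w2_weight names. A minimal sketch of the same lookup as a reusable helper (resolve_moe_weight is a hypothetical name, not part of the commit):

import torch

def resolve_moe_weight(layer: torch.nn.Module, name: str) -> torch.Tensor:
    # Prefer the packed parameter when the layer registered one
    # (e.g. "w13_weight_packed"), otherwise fall back to the plain name.
    packed = f"{name}_packed"
    return getattr(layer, packed) if hasattr(layer, packed) else getattr(layer, name)

# Equivalent to the diff above:
#   w1 = resolve_moe_weight(layer, "w13_weight")
#   w2 = resolve_moe_weight(layer, "w2_weight")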

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py

Lines changed: 49 additions & 2 deletions
@@ -35,7 +35,11 @@
 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
     is_valid_flashinfer_cutlass_fused_moe,
 )
-from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+    BatchedMarlinExperts,
+    MarlinExperts,
+    fused_marlin_moe,
+)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (  # noqa
     WNA16_SUPPORTED_BITS,
     WNA16_SUPPORTED_TYPES_MAP,

@@ -1562,7 +1566,50 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
-        return None
+        assert self.num_bits == 4 or self.num_bits == 8
+        config_builder = (
+            int4_w4a16_moe_quant_config
+            if self.num_bits == 4
+            else int8_w8a16_moe_quant_config
+        )
+
+        return config_builder(
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            w1_zp=None,
+            w2_zp=None,
+            block_shape=[0, self.group_size],
+        )
+
+    def select_gemm_impl(
+        self,
+        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
+        layer: torch.nn.Module,
+    ) -> mk.FusedMoEPermuteExpertsUnpermute:
+        if (
+            prepare_finalize.activation_format
+            == mk.FusedMoEActivationFormat.BatchedExperts
+        ):
+            return BatchedMarlinExperts(
+                max_num_tokens=prepare_finalize.max_num_tokens_per_rank(),
+                num_dispatchers=prepare_finalize.num_dispatchers(),
+                quant_config=self.moe_quant_config,
+            )
+        else:
+            layer.w13_weight = (
+                self.w13_weight_triton_tensor
+                if layer.w13_weight is None
+                else layer.w13_weight
+            )
+            layer.w2_weight = (
+                self.w2_weight_triton_tensor
+                if layer.w2_weight is None
+                else layer.w2_weight
+            )
+            assert all([w is not None for w in [layer.w13_weight, layer.w2_weight]])
+
+            assert self.moe_quant_config is not None
+            return MarlinExperts(self.moe_quant_config)
 
     def apply(
         self,
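
Taken together, the two new hooks are what let this quantization method participate in vLLM's modular MoE kernel framework (the mk module referenced above): get_fused_moe_quant_config now returns a real WNA16 config instead of None, and select_gemm_impl hands back a Marlin experts implementation matching the activation format. A rough sketch of how the framework composes them, assuming the FusedMoEModularKernel(prepare_finalize, fused_experts) constructor from modular_kernel.py; the actual wiring lives in vLLM's FusedMoE layer code, and build_wna16_marlin_kernel is a hypothetical helper, not part of the commit.

import vllm.model_executor.layers.fused_moe.modular_kernel as mk

def build_wna16_marlin_kernel(
    method,  # the compressed-tensors WNA16 MoE method shown in the diff
    layer,
    prepare_finalize: mk.FusedMoEPrepareAndFinalize,
) -> mk.FusedMoEModularKernel:
    # The quant config (int4_w4a16 or int8_w8a16) carries the weight scales and
    # group size that MarlinExperts / BatchedMarlinExperts need.
    quant_config = method.get_fused_moe_quant_config(layer)
    assert quant_config is not None

    # Contiguous activations select MarlinExperts; the batched-experts format
    # (e.g. expert-parallel all-to-all dispatch) selects BatchedMarlinExperts.
    experts = method.select_gemm_impl(prepare_finalize, layer)

    # The modular kernel chains prepare (dispatch/quantize), the expert GEMMs,
    # and finalize (combine) into one fused MoE forward.
    return mk.FusedMoEModularKernel(prepare_finalize, experts)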
