vllm/model_executor/layers/fused_moe/fused_marlin_moe.py (68 additions, 7 deletions)
@@ -499,11 +499,35 @@ def batched_fused_marlin_moe(
 
 
 class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
-    def __init__(self, quant_config: FusedMoEQuantConfig):
+    def __init__(
+        self,
+        quant_config: FusedMoEQuantConfig,
+        w13_g_idx: torch.Tensor | None = None,
+        w2_g_idx: torch.Tensor | None = None,
+        w13_g_idx_sort_indices: torch.Tensor | None = None,
+        w2_g_idx_sort_indices: torch.Tensor | None = None,
+        is_k_full: bool = True,
+    ):
         # TODO (varun) : Enable activation quantization
-        assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
+        assert quant_config.use_mxfp4_w4a16 or quant_config.use_int4_w4a16, (
+            "Supports only mxfp4_w4a16 or int4_w4a16"
+        )
+        self.w13_g_idx = w13_g_idx
+        self.w2_g_idx = w2_g_idx
+        self.w13_g_idx_sort_indices = w13_g_idx_sort_indices
+        self.w2_g_idx_sort_indices = w2_g_idx_sort_indices
+        self.is_k_full = is_k_full
         super().__init__(quant_config)
 
+    @property
+    def quant_type_id(self) -> int:
+        # uint4b8 is used for int4 weights; float4_e2m1f is used for mxfp4.
+        return (
+            scalar_types.uint4b8.id
+            if self.quant_config.use_int4_w4a16
+            else scalar_types.float4_e2m1f.id
+        )
+
     def moe_problem_size(
         self,
         a1: torch.Tensor,
@@ -533,8 +557,23 @@ def moe_problem_size(
 
 
 class MarlinExperts(MarlinExpertsBase):
-    def __init__(self, quant_config: FusedMoEQuantConfig):
-        super().__init__(quant_config)
+    def __init__(
+        self,
+        quant_config: FusedMoEQuantConfig,
+        w13_g_idx: torch.Tensor | None = None,
+        w2_g_idx: torch.Tensor | None = None,
+        w13_g_idx_sort_indices: torch.Tensor | None = None,
+        w2_g_idx_sort_indices: torch.Tensor | None = None,
+        is_k_full: bool = True,
+    ):
+        super().__init__(
+            quant_config,
+            w13_g_idx,
+            w2_g_idx,
+            w13_g_idx_sort_indices,
+            w2_g_idx_sort_indices,
+            is_k_full,
+        )
 
     def supports_expert_map(self) -> bool:
         return True
@@ -616,7 +655,7 @@ def apply(
             gating_output=None,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            quant_type_id=scalar_types.float4_e2m1f.id,  # works only for w4a16
+            quant_type_id=self.quant_type_id,
             apply_router_weight_on_input=apply_router_weight_on_input,
             global_num_experts=global_num_experts,
             activation=activation,
@@ -628,6 +667,11 @@
             # output buffer allocation. Please refer to workspace_shapes().
             intermediate_cache13=workspace2,
             intermediate_cache2=workspace13,
+            g_idx1=self.w13_g_idx,
+            g_idx2=self.w2_g_idx,
+            sort_indices1=self.w13_g_idx_sort_indices,
+            sort_indices2=self.w2_g_idx_sort_indices,
+            is_k_full=self.is_k_full,
         )
 
     def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
@@ -650,8 +694,20 @@ def __init__(
         max_num_tokens: int,
         num_dispatchers: int,
         quant_config: FusedMoEQuantConfig,
+        w13_g_idx: torch.Tensor | None = None,
+        w2_g_idx: torch.Tensor | None = None,
+        w13_g_idx_sort_indices: torch.Tensor | None = None,
+        w2_g_idx_sort_indices: torch.Tensor | None = None,
+        is_k_full: bool = True,
     ):
-        super().__init__(quant_config)
+        super().__init__(
+            quant_config,
+            w13_g_idx,
+            w2_g_idx,
+            w13_g_idx_sort_indices,
+            w2_g_idx_sort_indices,
+            is_k_full,
+        )
         self.max_num_tokens = max_num_tokens
         self.num_dispatchers = num_dispatchers
 
@@ -720,12 +776,17 @@ def apply(
             w1_scale=self.w1_scale,
             w2_scale=self.w2_scale,
             gating_output=None,
-            quant_type_id=scalar_types.float4_e2m1f.id,  # works only for w4a16
+            quant_type_id=self.quant_type_id,
             apply_router_weight_on_input=apply_router_weight_on_input,
             global_num_experts=global_num_experts,
             activation=activation,
             expert_map=expert_map,
             output=output,
             intermediate_cache13=workspace13,
             intermediate_cache2=workspace2,
+            g_idx1=self.w13_g_idx,
+            g_idx2=self.w2_g_idx,
+            sort_indices1=self.w13_g_idx_sort_indices,
+            sort_indices2=self.w2_g_idx_sort_indices,
+            is_k_full=self.is_k_full,
         )
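
For reference, a minimal construction sketch (not part of the diff above) showing how the extended MarlinExperts constructor and the new quant_type_id property fit together. The import path for int4_w4a16_moe_quant_config, the scalar_types import, and the tensor shapes are assumptions for illustration, not taken from this PR:

import torch

from vllm.model_executor.layers.fused_moe.config import int4_w4a16_moe_quant_config
from vllm.model_executor.layers.fused_moe.fused_marlin_moe import MarlinExperts
from vllm.scalar_types import scalar_types

# Illustrative sizes: experts, intermediate size, hidden size, quant group size.
E, N, K, group_size = 8, 4096, 2048, 128

quant_config = int4_w4a16_moe_quant_config(
    w1_scale=torch.ones(E, 2 * N, K // group_size, dtype=torch.half),
    w2_scale=torch.ones(E, K, N // group_size, dtype=torch.half),
    w1_zp=None,
    w2_zp=None,
    block_shape=[0, group_size],
)

# Leaving g_idx / sort_indices at their None defaults (and is_k_full=True)
# keeps the previous, non act-order behaviour.
experts = MarlinExperts(quant_config)

# The new quant_type_id property selects the Marlin scalar type from the config:
# uint4b8 for int4_w4a16, float4_e2m1f for mxfp4_w4a16.
assert experts.quant_type_id == scalar_types.uint4b8.id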
Second changed file (path not captured in this excerpt):
@@ -35,7 +35,11 @@
 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
     is_valid_flashinfer_cutlass_fused_moe,
 )
-from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+    BatchedMarlinExperts,
+    MarlinExperts,
+    fused_marlin_moe,
+)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (  # noqa
     WNA16_SUPPORTED_BITS,
     WNA16_SUPPORTED_TYPES_MAP,
@@ -1578,7 +1582,51 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
-        return None
+        if self.num_bits != 4:
+            return None
+        return int4_w4a16_moe_quant_config(
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            w1_zp=None,
+            w2_zp=None,
+            block_shape=[0, self.group_size],
+        )
+
+    def select_gemm_impl(
+        self,
+        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
+        layer: torch.nn.Module,
+    ) -> mk.FusedMoEPermuteExpertsUnpermute:
+        assert self.num_bits == 4, "only supporting w4"
+        layer.w13_weight = layer.w13_weight_packed
+        layer.w2_weight = layer.w2_weight_packed
+        assert all([w is not None for w in [layer.w13_weight, layer.w2_weight]])
+        assert self.moe_quant_config is not None
+        if (
+            prepare_finalize.activation_format
+            == mk.FusedMoEActivationFormat.BatchedExperts
+        ):
+            max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
+            assert max_num_tokens_per_rank is not None
+            return BatchedMarlinExperts(
+                max_num_tokens=max_num_tokens_per_rank,
+                num_dispatchers=prepare_finalize.num_dispatchers(),
+                quant_config=self.moe_quant_config,
+                w13_g_idx=layer.w13_weight_g_idx,
+                w2_g_idx=layer.w2_weight_g_idx,
+                w13_g_idx_sort_indices=layer.w13_g_idx_sort_indices,
+                w2_g_idx_sort_indices=layer.w2_g_idx_sort_indices,
+                is_k_full=self.is_k_full,
+            )
+        else:
+            return MarlinExperts(
+                quant_config=self.moe_quant_config,
+                w13_g_idx=layer.w13_weight_g_idx,
+                w2_g_idx=layer.w2_weight_g_idx,
+                w13_g_idx_sort_indices=layer.w13_g_idx_sort_indices,
+                w2_g_idx_sort_indices=layer.w2_g_idx_sort_indices,
+                is_k_full=self.is_k_full,
+            )
 
     def apply(
         self,
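
For context on the new arguments threaded through both files, a hedged sketch of where g_idx and the sort indices typically come from for an act-order (desc_act) checkpoint. The shapes and the argsort relationship are assumptions based on the existing Marlin path, not something this PR defines:

import torch

# Illustrative sizes: experts, hidden size (K), quantization group size.
E, K, group_size = 8, 2048, 128

# g_idx[e, k] = quantization group that input channel k of expert e belongs to.
# With act-order the channels are not laid out in group order, so the kernel
# also needs a per-expert permutation that sorts them by group.
g_idx = torch.randint(0, K // group_size, (E, K), dtype=torch.int32)
g_idx_sort_indices = torch.argsort(g_idx, dim=-1).to(torch.int32)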