17 changes: 13 additions & 4 deletions vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -501,7 +501,9 @@ def batched_fused_marlin_moe(
 class MarlinExpertsBase(mk.FusedMoEPermuteExpertsUnpermute):
     def __init__(self, quant_config: FusedMoEQuantConfig):
         # TODO (varun) : Enable activation quantization
-        assert quant_config.use_mxfp4_w4a16, "Supports only mxfp4_w4a16"
+        assert quant_config.use_mxfp4_w4a16 or quant_config.use_int4_w4a16, (
+            "Supports only mxfp4_w4a16 or int4_w4a16"
+        )
         super().__init__(quant_config)

     def moe_problem_size(
@@ -616,7 +618,11 @@ def apply(
             gating_output=None,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            quant_type_id=scalar_types.float4_e2m1f.id, # works only for w4a16
+            quant_type_id=(
+                scalar_types.uint4b8.id
+                if self.quant_config.use_int4_w4a16
+                else scalar_types.float4_e2m1f.id
+            ), # works only for w4a16
             apply_router_weight_on_input=apply_router_weight_on_input,
             global_num_experts=global_num_experts,
             activation=activation,
@@ -720,8 +726,11 @@ def apply(
             w1_scale=self.w1_scale,
             w2_scale=self.w2_scale,
             gating_output=None,
-            quant_type_id=scalar_types.float4_e2m1f.id, # works only for w4a16
-            apply_router_weight_on_input=apply_router_weight_on_input,
+            quant_type_id=(
+                scalar_types.uint4b8.id
+                if self.quant_config.use_int4_w4a16
+                else scalar_types.float4_e2m1f.id
+            ), # works only for w4a16
             global_num_experts=global_num_experts,
             activation=activation,
             expert_map=expert_map,
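Both apply() call sites above now pick the Marlin quant type with the same inline conditional. A minimal sketch of factoring that selection into a shared helper on MarlinExpertsBase (hypothetical name `_w4a16_quant_type_id`; assumes the module's existing `scalar_types` import), not part of this diff:

```python
# Sketch only: one place for the repeated uint4b8-vs-float4_e2m1f selection
# used by both apply() implementations.
def _w4a16_quant_type_id(self) -> int:
    # int4_w4a16 weights use the uint4b8 encoding (uint4 with a bias of 8);
    # mxfp4_w4a16 weights use the fp4 e2m1 encoding.
    if self.quant_config.use_int4_w4a16:
        return scalar_types.uint4b8.id
    return scalar_types.float4_e2m1f.id
```

The call sites would then pass `quant_type_id=self._w4a16_quant_type_id()`.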
vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -35,7 +35,11 @@
 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
     is_valid_flashinfer_cutlass_fused_moe,
 )
-from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe
+from vllm.model_executor.layers.fused_moe.fused_marlin_moe import (
+    BatchedMarlinExperts,
+    MarlinExperts,
+    fused_marlin_moe,
+)
 from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import ( # noqa
     WNA16_SUPPORTED_BITS,
     WNA16_SUPPORTED_TYPES_MAP,
@@ -1562,7 +1566,42 @@
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
-        return None
+        assert self.num_bits == 4 or self.num_bits == 8
+        config_builder = (
+            int4_w4a16_moe_quant_config
+            if self.num_bits == 4
+            else int8_w8a16_moe_quant_config
+        )
+
+        return config_builder(
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            w1_zp=None,
+            w2_zp=None,
+            block_shape=[0, self.group_size],
+        )
+
+    def select_gemm_impl(
+        self,
+        prepare_finalize: mk.FusedMoEPrepareAndFinalize,
+        layer: torch.nn.Module,
+    ) -> mk.FusedMoEPermuteExpertsUnpermute:
+
+        layer.w13_weight = layer.w13_weight_packed
+        layer.w2_weight = layer.w2_weight_packed
+        assert all([w is not None for w in [layer.w13_weight, layer.w2_weight]])
+        assert self.moe_quant_config is not None
+        if (
+            prepare_finalize.activation_format
+            == mk.FusedMoEActivationFormat.BatchedExperts
+        ):
+            return BatchedMarlinExperts(
+                max_num_tokens=prepare_finalize.max_num_tokens_per_rank(),
+                num_dispatchers=prepare_finalize.num_dispatchers(),
+                quant_config=self.moe_quant_config,
+            )
+        else:
+            return MarlinExperts(self.moe_quant_config)
Check failure on line 1604 in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
GitHub Actions / pre-commit: Argument "max_num_tokens" to "BatchedMarlinExperts" has incompatible type "int | None"; expected "int" [arg-type]

     def apply(
         self,