Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 23 additions & 2 deletions vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,10 @@ def backend_to_kernel_cls(
return [AiterExperts]

elif backend == Mxfp4MoeBackend.XPU:
raise NotImplementedError("XPU backend uses XpuMxfp4MoEMethod directly.")
from vllm.model_executor.layers.fused_moe.xpu_fused_moe import XPUExpertsMXFp4

return [XPUExpertsMXFp4]

else:
raise ValueError(f"Unknown MXFP4 MoE backend: {backend.value}")

Expand All @@ -156,6 +159,7 @@ def map_mxfp4_backend(runner_backend: str) -> Mxfp4MoeBackend:
"triton": Mxfp4MoeBackend.TRITON,
"marlin": Mxfp4MoeBackend.MARLIN,
"ck": Mxfp4MoeBackend.CK,
"xpu": Mxfp4MoeBackend.XPU,
}
if backend := mapping.get(runner_backend):
return backend
Expand All @@ -178,6 +182,7 @@ def _get_priority_backends() -> list[Mxfp4MoeBackend]:
Mxfp4MoeBackend.TRITON_UNFUSED,
Mxfp4MoeBackend.MARLIN,
Mxfp4MoeBackend.BATCHED_MARLIN,
Mxfp4MoeBackend.XPU,
]
return _AVAILABLE_BACKENDS

Expand Down Expand Up @@ -351,7 +356,13 @@ def _return_or_raise(
if current_platform.is_xpu():
backend = Mxfp4MoeBackend.XPU
logger.info_once(_make_log_backend(backend))
return backend, None
return _return_or_raise(
Mxfp4MoeBackend.XPU,
config,
kMxfp4Static,
None,
activation_format,
)

if current_platform.is_cuda() or current_platform.is_rocm():
raise NotImplementedError(
Expand Down Expand Up @@ -741,6 +752,16 @@ def _interleave_mxfp4_cutlass_sm90(w):
w13_bias,
w2_bias,
)
elif mxfp4_backend == Mxfp4MoeBackend.XPU:
# No additional transformation needed for XPU backend
return (
w13_weight,
w2_weight,
w13_weight_scale,
w2_weight_scale,
w13_bias,
w2_bias,
)
else:
raise ValueError(
f"Unsupported mxfp4_backend: {mxfp4_backend}: "
Expand Down
30 changes: 30 additions & 0 deletions vllm/model_executor/layers/fused_moe/xpu_fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
QuantKey,
kFp8DynamicTensorSym,
kFp8StaticTensorSym,
kMxfp4Static,
)
from vllm.platforms import current_platform

Expand All @@ -38,6 +39,7 @@ def __init__(
num_dispatchers,
)
self.is_fp8 = False
self.is_mxfp4 = False

@property
def expects_unquantized_inputs(self) -> bool:
Expand Down Expand Up @@ -137,6 +139,7 @@ def apply(
ep_size=self.moe_config.ep_size,
output=output,
is_fp8=self.is_fp8,
is_mxfp4=self.is_mxfp4,
)


Expand All @@ -155,3 +158,30 @@ def __init__(
num_dispatchers,
)
self.is_fp8 = True


class XPUExpertsMXFp4(XPUExperts):
    """XPU fused-MoE experts variant for MXFP4-quantized weights.

    Identical to the base ``XPUExperts`` except that it raises the
    ``is_mxfp4`` flag, which the base class forwards to the fused-MoE
    kernel call in ``apply``.
    """

    def __init__(
        self,
        moe_config: FusedMoEConfig,
        quant_config: FusedMoEQuantConfig,
        max_num_tokens: int | None = None,
        num_dispatchers: int | None = None,
    ):
        super().__init__(moe_config, quant_config, max_num_tokens, num_dispatchers)
        # Consumed by XPUExperts.apply(), which passes it through to the kernel.
        self.is_mxfp4 = True

    @staticmethod
    def _supports_quant_scheme(
        weight_key: QuantKey | None,
        activation_key: QuantKey | None,
    ) -> bool:
        """Accept only static MXFP4 weights with unquantized activations."""
        return (weight_key, activation_key) == (kMxfp4Static, None)
100 changes: 1 addition & 99 deletions vllm/model_executor/layers/quantization/mxfp4.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
FusedMoE,
FusedMoEConfig,
FusedMoEMethodBase,
MoEActivation,
)
from vllm.model_executor.layers.fused_moe import modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import (
Expand All @@ -33,7 +32,6 @@
)
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
from vllm.model_executor.utils import replace_parameter, set_weight_attrs
from vllm.platforms import current_platform

logger = init_logger(__name__)

Expand Down Expand Up @@ -80,10 +78,7 @@ def get_quant_method(
)
return UnquantizedLinearMethod()
elif isinstance(layer, FusedMoE):
if current_platform.is_xpu():
return XpuMxfp4MoEMethod(layer.moe_config)
else:
return Mxfp4MoEMethod(layer.moe_config)
return Mxfp4MoEMethod(layer.moe_config)
elif isinstance(layer, Attention):
logger.debug_once(
"MXFP4 attention layer is not implemented. "
Expand Down Expand Up @@ -420,96 +415,3 @@ def apply_monolithic(
expert_map=layer.expert_map,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
)


class XpuMxfp4MoEMethod(Mxfp4MoEMethod):
    """MXFP4 fused-MoE method for the XPU platform.

    Reuses the parent's weight creation but skips post-load weight
    transformation and dispatches the whole MoE layer monolithically
    through the ``vllm_xpu_kernels`` fused-MoE kernel.
    """

    def __init__(self, moe_config: FusedMoEConfig):
        super().__init__(moe_config)
        # Kept for access in apply_monolithic / weight handling.
        self.moe_config = moe_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Create weights via the parent, remembering the unpadded hidden size."""
        super().create_weights(
            layer,
            num_experts,
            hidden_size,
            intermediate_size_per_partition,
            params_dtype,
            **extra_weight_attrs,
        )
        # Record the pre-padding hidden size; presumably used elsewhere to
        # slice padded activations — TODO confirm against callers.
        self.original_hidden_size = hidden_size

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Intentional no-op: XPU consumes the loaded MXFP4 weights as-is,
        # with no interleaving/repacking step.
        pass

    @property
    def is_monolithic(self) -> bool:
        # Signals the FusedMoE layer to call apply_monolithic() instead of
        # the modular-kernel path.
        return True

    def apply_monolithic(
        self,
        layer: FusedMoE,
        x: torch.Tensor,
        router_logits: torch.Tensor,
    ) -> torch.Tensor:
        """Run routing + fused MXFP4 MoE in one kernel call.

        Computes top-k expert routing on-device (grouped or softmax top-k),
        then invokes the XPU fused-MoE kernel with the layer's MXFP4 weights.

        Args:
            layer: The FusedMoE layer holding weights, scales, and routing config.
            x: Input activations of shape (num_tokens, hidden_size).
            router_logits: Per-token expert logits for routing.

        Returns:
            The MoE output tensor produced by ``xpu_fused_moe``.
        """
        assert layer.activation == MoEActivation.SWIGLUOAI, (
            "Only swiglu_oai activation is supported for "
            f"XPU MXFP4 MoE, not {layer.activation}."
        )
        # Local import: vllm_xpu_kernels is only present on XPU installs.
        from vllm_xpu_kernels.fused_moe_interface import xpu_fused_moe

        M, _ = x.size()
        # Output buffers filled in-place by the routing ops below.
        routing_weights = torch.empty(
            M, layer.top_k, dtype=torch.float32, device=x.device
        )
        selected_experts = torch.empty(
            M, layer.top_k, dtype=torch.int32, device=x.device
        )
        token_expert_indices = torch.empty(
            M, layer.top_k, dtype=torch.int32, device=x.device
        )

        if layer.use_grouped_topk:
            # Grouped top-k routing (e.g. expert-group restricted selection);
            # returns fresh tensors rather than filling the buffers above.
            routing_weights, selected_experts = torch.ops._moe_C.fused_grouped_topk(
                x,
                router_logits,
                layer.top_k,
                layer.renormalize,
                n_expert_group=layer.num_expert_group,
                n_topk_group=layer.topk_group,
                scoring_func=layer.scoring_func,
                routed_scaling_factor=layer.routed_scaling_factor,
                bias=layer.e_score_correction_bias,
            )
        else:
            # Plain softmax top-k; writes into the preallocated buffers.
            torch.ops._moe_C.topk_softmax(
                routing_weights,
                selected_experts,
                token_expert_indices,
                router_logits,
                layer.renormalize,
                layer.e_score_correction_bias,
            )

        return xpu_fused_moe(
            hidden_states=x,
            w13=layer.w13_weight,
            w13_bias=layer.w13_bias if self.moe.has_bias else None,
            w13_scales=layer.w13_weight_scale,
            w2=layer.w2_weight,
            w2_bias=layer.w2_bias if self.moe.has_bias else None,
            w2_scales=layer.w2_weight_scale,
            topk_weights=routing_weights,
            topk_ids=selected_experts,
            n_experts_per_token=layer.top_k,
            activation=layer.activation.value,
            num_experts=layer.local_num_experts,
            is_mxfp4=True,
        )
Loading