Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 3 additions & 39 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,6 @@
disable_inplace,
moe_kernel_quantize_input,
)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import dequant_mxfp4
from vllm.model_executor.layers.quantization.utils.mxfp6_utils import dequant_mxfp6
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import OCP_MX_Scheme
from vllm.model_executor.utils import maybe_disable_graph_partition
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
Expand Down Expand Up @@ -2017,7 +2014,8 @@ def fused_experts_impl(
# Check constraints.
if use_int4_w4a16:
assert hidden_states.size(1) // 2 == w1.size(2), "Hidden size mismatch"
elif ocp_mx_scheme is not None:
elif ocp_mx_scheme is not None and w1_scale is not None:
# Size checks for packed weights (native MXFP or before dequantization)
if ocp_mx_scheme in {
"w_mxfp4_a_mxfp4",
"w_mxfp4_a_mxfp6_e3m2",
Expand All @@ -2035,6 +2033,7 @@ def fused_experts_impl(
else:
raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")
else:
# Normal weights or MXFP emulation (weights already dequantized)
assert hidden_states.size(1) == w1.size(2), (
f"Hidden size mismatch {hidden_states.size(1)} != {w1.size(2)}"
)
Expand Down Expand Up @@ -2112,41 +2111,6 @@ def fused_experts_impl(
else:
out_hidden_states = torch.empty_like(hidden_states)

if ocp_mx_scheme is not None:
# TODO: On platforms for which `current_platform.supports_mx()` is True
# and for which we have a native OCP mx fused MOE kernel,
# this dequantization step should not be done.
if ocp_mx_scheme in {
OCP_MX_Scheme.w_mxfp4_a_mxfp4,
OCP_MX_Scheme.w_mxfp4_a_mxfp6_e3m2,
OCP_MX_Scheme.w_mxfp4_a_mxfp6_e2m3,
}:
# Weight has to be dequantized for mxfp4 emulation.
w1 = dequant_mxfp4(w1, w1_scale, hidden_states.dtype)
w1_scale = None
w2 = dequant_mxfp4(w2, w2_scale, hidden_states.dtype)
w2_scale = None
elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e3m2_a_mxfp6_e3m2:
w1 = dequant_mxfp6(
w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
)
w1_scale = None
w2 = dequant_mxfp6(
w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=hidden_states.dtype
)
w2_scale = None
elif ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e2m3_a_mxfp6_e2m3:
w1 = dequant_mxfp6(
w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
)
w1_scale = None
w2 = dequant_mxfp6(
w2, w2_scale, quant_dtype="fp6_e2m3", float_dtype=hidden_states.dtype
)
w2_scale = None
else:
raise NotImplementedError(f"Unsupported ocp_mx_scheme={ocp_mx_scheme}")

for chunk in range((num_tokens // CHUNK_SIZE) + 1):
begin_chunk_idx, end_chunk_idx = (
chunk * CHUNK_SIZE,
Expand Down
101 changes: 90 additions & 11 deletions vllm/model_executor/layers/quantization/quark/quark_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
prepare_fp8_moe_layer_for_marlin,
)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
dequant_mxfp4,
)
from vllm.model_executor.layers.quantization.utils.mxfp6_utils import (
dequant_mxfp6,
)
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
OCP_MX_BLOCK_SIZE,
OCP_MX_Scheme,
Expand Down Expand Up @@ -635,6 +641,56 @@ def get_packed_dim(self, dim: int, quant_dtype: str):
assert (dim * 3) % 4 == 0
return (dim * 3) // 4

def _dequantize_weights(
    self,
    w1: torch.Tensor,
    w2: torch.Tensor,
    w1_scale: torch.Tensor,
    w2_scale: torch.Tensor,
    dtype: torch.dtype,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Dequantize MXFP4/MXFP6 weights to high precision for emulation.

    Args:
        w1: Packed w13 weights (uint8).
        w2: Packed w2 weights (uint8).
        w1_scale: Weight scales for w13.
        w2_scale: Weight scales for w2.
        dtype: Target dtype for dequantization (fp16/bf16/fp32).

    Returns:
        Tuple of (dequantized_w1, dequantized_w2).

    Raises:
        NotImplementedError: If ``self.ocp_mx_scheme`` is not one of the
            supported MXFP4/MXFP6 schemes handled below.
    """
    # The three w_mxfp4_* schemes share the same MXFP4 weight packing;
    # they differ only in the activation format, which is irrelevant here.
    if self.ocp_mx_scheme in {
        OCP_MX_Scheme.w_mxfp4_a_mxfp4,
        OCP_MX_Scheme.w_mxfp4_a_mxfp6_e3m2,
        OCP_MX_Scheme.w_mxfp4_a_mxfp6_e2m3,
    }:
        # MXFP4 weights
        dequant_w1 = dequant_mxfp4(w1, w1_scale, dtype)
        dequant_w2 = dequant_mxfp4(w2, w2_scale, dtype)
    elif self.ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e3m2_a_mxfp6_e3m2:
        # MXFP6 e3m2 weights
        dequant_w1 = dequant_mxfp6(
            w1, w1_scale, quant_dtype="fp6_e3m2", float_dtype=dtype
        )
        dequant_w2 = dequant_mxfp6(
            w2, w2_scale, quant_dtype="fp6_e3m2", float_dtype=dtype
        )
    elif self.ocp_mx_scheme == OCP_MX_Scheme.w_mxfp6_e2m3_a_mxfp6_e2m3:
        # MXFP6 e2m3 weights
        dequant_w1 = dequant_mxfp6(
            w1, w1_scale, quant_dtype="fp6_e2m3", float_dtype=dtype
        )
        dequant_w2 = dequant_mxfp6(
            w2, w2_scale, quant_dtype="fp6_e2m3", float_dtype=dtype
        )
    else:
        raise NotImplementedError(f"Unsupported ocp_mx_scheme={self.ocp_mx_scheme}")

    return dequant_w1, dequant_w2

def create_weights(
self,
layer: torch.nn.Module,
Expand Down Expand Up @@ -736,15 +792,29 @@ def process_weights_after_loading(self, layer):
def get_fused_moe_quant_config(
    self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
    """Build the fused-MoE quantization config for this OCP MX scheme.

    In emulation mode the weights are dequantized to high precision in
    ``apply()`` before the fused kernel runs, so no weight scales are
    handed to the kernel; intermediate activations are still quantized
    according to ``quant_dtype``. In native mode the packed weights and
    their per-block scales are passed through unchanged.
    """
    if self.emulate:
        # Emulation mode: weights are dequantized in apply(), but
        # intermediate activations still need quantization. Set weight
        # scales to None since weights are already dequantized.
        w1_scale = None
        w2_scale = None
    else:
        w1_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale
    return ocp_mx_moe_quant_config(
        quant_dtype=self.input_dtype,
        weight_dtype=self.weight_dtype,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=None,
        a2_scale=None,
        block_shape=None,
    )

@property
def allow_inplace(self) -> bool:
Expand Down Expand Up @@ -780,10 +850,19 @@ def apply(
else:
from vllm.model_executor.layers.fused_moe import fused_experts

out = fused_experts(
x,
# Dequantize weights for MXFP emulation
dequant_w1, dequant_w2 = self._dequantize_weights(
layer.w13_weight,
layer.w2_weight,
layer.w13_weight_scale,
layer.w2_weight_scale,
x.dtype,
)

out = fused_experts(
x,
dequant_w1,
dequant_w2,
Comment thread
cursor[bot] marked this conversation as resolved.
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
Expand Down