2 changes: 1 addition & 1 deletion docs/design/moe_kernel_features.md
@@ -36,7 +36,7 @@ th {
| deepep_high_throughput | standard | fp8 | G(128),A,T<sup>2</sup> | Y | Y | [`DeepEPHTPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.deepep_ht.DeepEPHTPrepareAndFinalize] |
| deepep_low_latency | batched | fp8 | G(128),A,T<sup>3</sup> | Y | Y | [`DeepEPLLPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.deepep_ll.DeepEPLLPrepareAndFinalize] |
| flashinfer_nvlink_two_sided | standard | nvfp4,fp8 | G,A,T | N | N | [`FlashInferNVLinkTwoSidedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.flashinfer_nvlink_two_sided.FlashInferNVLinkTwoSidedPrepareAndFinalize] |
| flashinfer_nvlink_one_sided | standard | nvfp4 | G,A,T | N | N | [`FlashInferNVLinkOneSidedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.flashinfer_nvlink_one_sided.FlashInferNVLinkOneSidedPrepareAndFinalize] |
| flashinfer_nvlink_one_sided | standard | nvfp4,bf16,mxfp8 | G,A,T | N | N | [`FlashInferNVLinkOneSidedPrepareAndFinalize`][vllm.model_executor.layers.fused_moe.prepare_finalize.flashinfer_nvlink_one_sided.FlashInferNVLinkOneSidedPrepareAndFinalize] |

!!! info "Table key"
1. All types: mxfp4, nvfp4, int4, int8, fp8
10 changes: 8 additions & 2 deletions vllm/distributed/device_communicators/all2all.py
@@ -577,6 +577,8 @@ def initialize(
top_k: int,
num_experts: int,
hidden_size: int,
dispatch_dtype_bytes_per_elem: int = 0,
dispatch_scale_bytes_per_token: int = 0,
):
"""Initialize the MoeAlltoAll workspace."""
if self.initialized:
@@ -607,9 +609,13 @@
ep_config = MnnvlConfig(
comm_backend=CustomCommunicator(self.cpu_group),
)
if dispatch_dtype_bytes_per_elem == 0:
hidden_bytes = hidden_size // 2
else:
hidden_bytes = hidden_size * dispatch_dtype_bytes_per_elem
total_dispatch_payload_size_per_token = (
hidden_size // 2 # nvfp4 hidden states
+ hidden_size // 16 # fp8 scaling factors
hidden_bytes
+ dispatch_scale_bytes_per_token
+ top_k * 4 # int32 topks ids
+ top_k * 4 # float32 topk weights
)
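To make the sizing above concrete, here is a minimal standalone sketch (not vLLM code) of the per-token dispatch payload in bytes for the three dispatch dtypes this change enables; the `hidden_size` and `top_k` values are illustrative assumptions, not taken from the PR.

```python
# Standalone sketch of the per-token payload sizing encoded by the new
# dispatch_dtype_bytes_per_elem / dispatch_scale_bytes_per_token parameters.
# hidden_size and top_k below are illustrative, not from the PR.

def payload_bytes_per_token(
    hidden_size: int,
    top_k: int,
    dispatch_dtype_bytes_per_elem: int,
    dispatch_scale_bytes_per_token: int,
) -> int:
    if dispatch_dtype_bytes_per_elem == 0:
        hidden_bytes = hidden_size // 2  # legacy default: packed nvfp4 (4 bits/elem)
    else:
        hidden_bytes = hidden_size * dispatch_dtype_bytes_per_elem
    return (
        hidden_bytes
        + dispatch_scale_bytes_per_token  # block scales (0 for bf16 dispatch)
        + top_k * 4  # int32 topk ids
        + top_k * 4  # float32 topk weights
    )

hidden_size, top_k = 7168, 8
print(payload_bytes_per_token(hidden_size, top_k, 0, hidden_size // 16))  # nvfp4
print(payload_bytes_per_token(hidden_size, top_k, 2, 0))                  # bf16
print(payload_bytes_per_token(hidden_size, top_k, 1, hidden_size // 32))  # mxfp8
```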
30 changes: 22 additions & 8 deletions vllm/model_executor/layers/fused_moe/all2all_utils.py
@@ -228,23 +228,37 @@ def maybe_make_prepare_finalize(

elif moe.use_fi_nvl_one_sided_kernels:
assert quant_config is not None
if quant_config.quant_dtype != "nvfp4":
raise ValueError(
"The 'flashinfer_nvlink_one_sided' all2all backend only "
"supports nvfp4 activation quantization, but got "
f"quant_dtype={quant_config.quant_dtype!r}. Use a different "
"all2all backend (e.g. 'flashinfer_nvlink_two_sided' or "
"'allgather_reducescatter') for non-nvfp4 models."
)
max_num_tokens = (
get_current_vllm_config().scheduler_config.max_num_batched_tokens
)
if quant_config.quant_dtype is None:
dispatch_dtype_bytes_per_elem = 2
dispatch_scale_bytes_per_token = 0
elif quant_config.quant_dtype == "nvfp4":
dispatch_dtype_bytes_per_elem = 0
dispatch_scale_bytes_per_token = moe.hidden_dim // 16
elif quant_config.quant_dtype == "mxfp8":
dispatch_dtype_bytes_per_elem = 1
align = quant_config.mx_alignment
if align > 0:
padded_k = ((moe.hidden_dim + align - 1) // align) * align
else:
padded_k = moe.hidden_dim
dispatch_scale_bytes_per_token = padded_k // 32
else:
raise NotImplementedError(
"flashinfer_nvlink_one_sided dispatch supports nvfp4, mxfp8, "
"and bf16 (quant_dtype=None) today; got "
f"quant_dtype={quant_config.quant_dtype!r}"
)
prepare_finalize = FlashInferNVLinkOneSidedPrepareAndFinalize(
max_num_tokens=max_num_tokens,
top_k=moe.experts_per_token,
num_experts=moe.num_experts,
hidden_size=moe.hidden_dim,
num_dispatchers=all2all_manager.world_size,
dispatch_dtype_bytes_per_elem=dispatch_dtype_bytes_per_elem,
dispatch_scale_bytes_per_token=dispatch_scale_bytes_per_token,
)

elif moe.use_ag_rs_all2all_kernels and allow_new_interface:
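As a worked illustration of the mxfp8 branch above (the numbers are assumptions, not taken from the PR): the hidden dim is rounded up to the `mx_alignment` boundary before one scale byte per 32-element block is budgeted.

```python
# Illustrative arithmetic for the mxfp8 scale sizing above; hidden_dim is an
# assumed value, and 256 matches the mx_alignment set for the TRTLLM
# mxfp4+mxfp8 backend further down in this PR.
hidden_dim = 2880
align = 256
padded_k = ((hidden_dim + align - 1) // align) * align   # -> 3072
dispatch_scale_bytes_per_token = padded_k // 32           # one scale byte per 32 elems -> 96
print(padded_k, dispatch_scale_bytes_per_token)
```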
4 changes: 4 additions & 0 deletions vllm/model_executor/layers/fused_moe/config.py
@@ -254,6 +254,8 @@ class FusedMoEQuantConfig:
gemm1_beta: float | None = None
gemm1_clamp_limit: float | None = None

mx_alignment: int = 0

def __post_init__(self):
assert not self.per_act_token_quant or self.block_shape is None, (
"illegal quantization"
@@ -712,6 +714,7 @@ def mxfp4_mxfp8_moe_quant_config(
gemm1_alpha: float | None = None,
gemm1_beta: float | None = None,
gemm1_clamp_limit: float | None = None,
mx_alignment: int = 0,
) -> FusedMoEQuantConfig:
"""
Construct a quant config for mxfp8 activations and mxfp4 weights.
@@ -724,6 +727,7 @@
gemm1_alpha=gemm1_alpha,
gemm1_beta=gemm1_beta,
gemm1_clamp_limit=gemm1_clamp_limit,
mx_alignment=mx_alignment,
)
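For orientation, a minimal sketch (not vLLM code) of how the new `mx_alignment` field is meant to be used: 0, the default, requests no extra hidden-dim padding before mxfp8 activation quantization, while the TRTLLM mxfp4+mxfp8 backend opts into 256 (see the oracle/mxfp4.py hunk below).

```python
# Minimal sketch of the mx_alignment semantics assumed in this PR;
# QuantConfigSketch is a stand-in, not the real FusedMoEQuantConfig.
from dataclasses import dataclass


@dataclass
class QuantConfigSketch:
    mx_alignment: int = 0  # 0 -> no alignment padding of the hidden dim


trtllm_cfg = QuantConfigSketch(mx_alignment=256)  # TRTLLM mxfp4+mxfp8 backend
cutlass_cfg = QuantConfigSketch()                 # CUTLASS backend keeps the default
print(trtllm_cfg.mx_alignment, cutlass_cfg.mx_alignment)
```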


53 changes: 17 additions & 36 deletions vllm/model_executor/layers/fused_moe/experts/trtllm_mxfp4_moe.py
@@ -44,6 +44,9 @@ def __init__(
moe_config.intermediate_size_per_partition
)
self.hidden_dim = moe_config.hidden_dim
self.hidden_dim_unpadded = (
moe_config.hidden_dim_unpadded or moe_config.hidden_dim
)
self.local_num_experts = moe_config.num_local_experts
self.ep_rank = moe_config.moe_parallel_config.ep_rank

@@ -82,9 +85,6 @@ def __init__(
get_current_vllm_config().compilation_config.max_cudagraph_capture_size
)

# P1-5 fix: use public quant_dtype property instead of private _a1
self.use_mxfp8_input = quant_config.quant_dtype == "mxfp8"

@staticmethod
def _supports_current_device() -> bool:
p = current_platform
@@ -121,8 +121,7 @@ def supports_expert_map(self) -> bool:

@property
def expects_unquantized_inputs(self) -> bool:
# Expert handles MXFP8 quantization internally if needed
return True
return False


class TrtLlmMxfp4ExpertsMonolithic(
@@ -181,24 +180,19 @@ def apply(
) -> torch.Tensor:
from flashinfer import trtllm_fp4_block_scale_moe

# Handle input quantization
if self.use_mxfp8_input:
from flashinfer import mxfp8_quantize

x_quant, x_scale = mxfp8_quantize(
hidden_states,
is_sf_swizzled_layout=False,
alignment=256,
)
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(
*hidden_states.shape[:-1], -1
)
if a1q_scale is not None:
x_quant = hidden_states
x_scale = a1q_scale.view(torch.float8_e4m3fn)
else:
assert hidden_states.dtype == torch.bfloat16
x_quant = hidden_states
x_scale = None

output = torch.empty_like(hidden_states)
output = torch.empty(
*hidden_states.shape[:-1],
self.hidden_dim_unpadded,
dtype=torch.bfloat16,
device=hidden_states.device,
)

from vllm.utils.flashinfer import _is_fi_autotuning, autotune

@@ -244,10 +238,6 @@ class TrtLlmMxfp4ExpertsModular(TrtLlmMxfp4ExpertsBase, mk.FusedMoEExpertsModula
Moved from trtllm_moe.py.
"""

@property
def expects_unquantized_inputs(self) -> bool:
return True

@staticmethod
def _supports_parallel_config(
moe_parallel_config: FusedMoEParallelConfig,
@@ -284,7 +274,7 @@ def workspace_shapes(
# The workspaces for this implementation are managed by flashinfer.
workspace1 = (0,)
workspace2 = (0,)
output = (M, K)
output = (M, self.hidden_dim_unpadded)
return (workspace1, workspace2, output)

def apply(
@@ -310,18 +300,9 @@ def apply(
intermediate_size = self.intermediate_size_per_partition
local_expert_offset = self.moe_config.ep_rank * local_num_experts

# Handle input quantization
if self.use_mxfp8_input:
from flashinfer import mxfp8_quantize

x_quant, x_scale = mxfp8_quantize(
hidden_states,
is_sf_swizzled_layout=False,
alignment=256,
)
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(
*hidden_states.shape[:-1], -1
)
if a1q_scale is not None:
x_quant = hidden_states
x_scale = a1q_scale.view(torch.float8_e4m3fn)
else:
assert hidden_states.dtype == torch.bfloat16
x_quant = hidden_states
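The net effect of the changes in this file is that input quantization moves out of the expert: `expects_unquantized_inputs` is now False, the expert reuses the `a1q_scale` produced upstream, and the bf16 output is allocated at the model's unpadded hidden size. A rough standalone sketch of the shape bookkeeping, assuming the quantized activations are padded to the alignment boundary (as the removed `alignment=256` argument suggests):

```python
# Rough sketch (not vLLM code) of the shape bookkeeping behind
# hidden_dim_unpadded: the dispatched mxfp8 payload may carry a padded hidden
# dim, but the expert's bf16 output is allocated at the true model width.
# All values are illustrative assumptions.
num_tokens = 4
hidden_dim_unpadded = 2880
align = 256
hidden_dim_padded = ((hidden_dim_unpadded + align - 1) // align) * align  # 3072

quantized_activation_shape = (num_tokens, hidden_dim_padded)   # fp8 payload after dispatch
scale_shape = (num_tokens, hidden_dim_padded // 32)            # per-32-element block scales
output_shape = (num_tokens, hidden_dim_unpadded)               # bf16 output allocation
print(quantized_activation_shape, scale_shape, output_shape)
```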
17 changes: 12 additions & 5 deletions vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
@@ -1195,10 +1195,18 @@ def make_mxfp4_moe_quant_config(
gemm1_beta=gemm1_beta,
gemm1_clamp_limit=swiglu_limit,
)
elif mxfp4_backend in (
Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8,
Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8,
):
elif mxfp4_backend == Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8:
return mxfp4_mxfp8_moe_quant_config(
w1_bias=w1_bias,
w2_bias=w2_bias,
w1_scale=w1_scale,
w2_scale=w2_scale,
gemm1_alpha=gemm1_alpha,
gemm1_beta=gemm1_beta,
gemm1_clamp_limit=swiglu_limit,
mx_alignment=256,
)
elif mxfp4_backend == Mxfp4MoeBackend.FLASHINFER_CUTLASS_MXFP4_MXFP8:
return mxfp4_mxfp8_moe_quant_config(
w1_bias=w1_bias,
w2_bias=w2_bias,
@@ -1250,7 +1258,6 @@ def make_mxfp4_moe_kernel(
"""Create a FusedMoEKernel for the given MXFP4 backend."""
is_monolithic = issubclass(experts_cls, mk.FusedMoEExpertsMonolithic)

# Create Prepare/Finalize.
prepare_finalize = maybe_make_prepare_finalize(
moe=moe_config,
quant_config=moe_quant_config,
vllm/model_executor/layers/fused_moe/prepare_finalize/flashinfer_nvlink_one_sided.py
@@ -31,13 +31,16 @@ def __init__(
num_experts: int,
hidden_size: int,
num_dispatchers: int = 1,
dispatch_dtype_bytes_per_elem: int = 0,
dispatch_scale_bytes_per_token: int = 0,
):
super().__init__()
self.max_num_tokens = max_num_tokens
self.top_k = top_k
self.num_experts = num_experts
self.hidden_size = hidden_size
self.num_dispatchers_ = num_dispatchers
self.scale_elems_per_token = dispatch_scale_bytes_per_token

device_communicator = get_ep_group().device_communicator
assert device_communicator is not None
@@ -49,6 +52,8 @@
top_k=self.top_k,
num_experts=self.num_experts,
hidden_size=self.hidden_size,
dispatch_dtype_bytes_per_elem=dispatch_dtype_bytes_per_elem,
dispatch_scale_bytes_per_token=dispatch_scale_bytes_per_token,
)

@property
@@ -92,19 +97,24 @@ def prepare(
else a1.shape[0]
)

a1q, a1q_scale = moe_kernel_quantize_input(
a1,
quant_config.a1_gscale,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
is_fp4_scale_swizzled=False, # delay swizzle to after comm
)
if defer_input_quant:
a1q, a1q_scale = a1, None
else:
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
quant_config.a1_gscale,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
is_fp4_scale_swizzled=False, # delay swizzle to after comm
mx_alignment=quant_config.mx_alignment,
)

payloads = []
payloads.append(a1q)
if a1q_scale is not None:
payloads.append(a1q_scale)
topk_ids_payload_index = len(payloads)
payloads.append(topk_ids)
payloads.append(topk_weights)

@@ -113,6 +123,8 @@
token_selected_experts=topk_ids,
input_payloads=payloads,
runtime_max_tokens_per_rank=self.runtime_max_tokens_per_rank,
invalid_token_expert_id=-1, # Follow TRTLLM Pattern
expert_id_payload_index=topk_ids_payload_index,
)
if a1q_scale is not None:
a1q_recv, a1q_scale_recv, topk_ids_recv, topk_weights_recv = recv_payloads
@@ -124,7 +136,8 @@
a1q_scale_recv = a1q_scale_recv.view(-1, a1q_scale_recv.shape[-1])
a1q_scale_recv = a1q_scale_recv.view(torch.uint8)
a1q_scale_recv = nvfp4_block_scale_interleave(a1q_scale_recv)
a1q_scale_recv = a1q_scale_recv.view(-1, self.hidden_size // 16)
assert self.scale_elems_per_token > 0
a1q_scale_recv = a1q_scale_recv.view(-1, self.scale_elems_per_token)
else:
a1q_recv, topk_ids_recv, topk_weights_recv = recv_payloads
a1q_scale_recv = None
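One subtlety in `prepare()` above: the scale payload is only appended when the input was quantized before dispatch, so the position of the expert-id payload shifts between the bf16 and quantized paths. A small standalone sketch of that layout:

```python
# Standalone sketch (not vLLM code) of the dispatch payload layout used in
# prepare() above: the scale tensor is only present when quantization happened
# before the all-to-all, which is why expert_id_payload_index is computed
# dynamically.
def payload_layout(quantized_before_dispatch: bool) -> tuple[list[str], int]:
    payloads = ["a1q"]
    if quantized_before_dispatch:
        payloads.append("a1q_scale")
    expert_id_payload_index = len(payloads)  # index of topk_ids in the payload list
    payloads += ["topk_ids", "topk_weights"]
    return payloads, expert_id_payload_index

print(payload_layout(True))   # (['a1q', 'a1q_scale', 'topk_ids', 'topk_weights'], 2)
print(payload_layout(False))  # (['a1q', 'topk_ids', 'topk_weights'], 1)
```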
@@ -174,6 +174,7 @@ def flashinfer_alltoall_dispatch(
# the hidden states, breaking the A2A kernel. So, we
# delay the swizzling until after the A2A.
is_fp4_scale_swizzled=False,
mx_alignment=quant_config.mx_alignment,
)

x = MnnvlMoe.mnnvl_moe_alltoallv(
@@ -40,6 +40,7 @@ def _quantize_and_setup_dispatch(
per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape,
is_fp4_scale_swizzled=False,
mx_alignment=quant_config.mx_alignment,
)

# Skip gathering scales if we have static quantization
@@ -31,6 +31,7 @@ def _quantize_input(
per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape,
is_fp4_scale_swizzled=quant_config.is_nvfp4_scale_swizzled,
mx_alignment=quant_config.mx_alignment,
)

return a1q, a1q_scale
7 changes: 5 additions & 2 deletions vllm/model_executor/layers/fused_moe/utils.py
@@ -208,11 +208,12 @@ def _mxfp8_e4m3_quantize(
per_act_token_quant: bool,
block_shape: list[int] | None = None,
is_sf_swizzled_layout: bool = False,
mx_alignment: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
assert A_scale is None
assert not per_act_token_quant
assert block_shape is None or block_shape == [1, 32]
return mxfp8_e4m3_quantize(A, is_sf_swizzled_layout)
return mxfp8_e4m3_quantize(A, is_sf_swizzled_layout, mx_alignment)


def _mxfp6_e3m2_quantize(
@@ -258,6 +259,7 @@ def moe_kernel_quantize_input(
is_fp4_scale_swizzled: bool = True,
ocp_mx_scheme: str | None = None,
quantization_emulation: bool = False,
mx_alignment: int = 0,
) -> tuple[torch.Tensor, torch.Tensor | None]:
# Handle OCP MX scheme that requires QDQ (quantize-dequantize) for emulation
if ocp_mx_scheme is not None:
@@ -319,7 +321,8 @@
A_scale,
per_act_token_quant,
block_shape,
is_sf_swizzled_layout=is_fp4_scale_swizzled,
is_sf_swizzled_layout=False,
mx_alignment=mx_alignment,
)
elif quant_dtype == "mxfp6_e3m2":
if not quantization_emulation: