Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions vllm/model_executor/layers/fused_moe/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,11 @@
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
if quant_dtype is None and isinstance(quant_config, Fp8Config):
quant_dtype = torch.float8_e4m3fn

from vllm.model_executor.layers.quantization.mxfp4 import Mxfp4Config
if (quant_dtype is None and isinstance(quant_config, Mxfp4Config) and envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8):
quant_dtype = "mxfp8"

Check failure on line 452 in vllm/model_executor/layers/fused_moe/config.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (E501)

vllm/model_executor/layers/fused_moe/config.py:452:81: E501 Line too long (81 > 80)
from vllm.model_executor.layers.quantization.modelopt import (
ModelOptNvFp4Config)
if quant_dtype is None and isinstance(quant_config,
Expand Down
6 changes: 5 additions & 1 deletion vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,11 @@

class FusedMoEMethodBase(QuantizeMethodBase):

moe: FusedMoEConfig
    def __init__(self, moe: FusedMoEConfig):
        """Store the MoE config and reset per-instance kernel state.

        ``fused_experts`` starts as ``None``; subclasses may assign a
        concrete kernel callable (see e.g. the unquantized method, which
        installs the default ``fused_experts`` function).
        """
        super().__init__()
        self.moe = moe
        # Optional modular-kernel callable; None until a subclass sets one.
        self.fused_experts: Optional[Callable] = None
        # Dtype for top-k index tensors; None means "kernel default".
        self.topk_indices_dtype = None

@abstractmethod
def create_weights(self, layer: torch.nn.Module, num_experts: int,
Expand Down Expand Up @@ -251,7 +255,7 @@
"""MoE method without quantization."""

def __init__(self, moe: FusedMoEConfig):
super().__init__()

Check failure on line 258 in vllm/model_executor/layers/fused_moe/layer.py

View workflow job for this annotation

GitHub Actions / pre-commit

Missing positional argument "moe" in call to "__init__" of "FusedMoEMethodBase" [call-arg]

Check failure on line 258 in vllm/model_executor/layers/fused_moe/layer.py

View workflow job for this annotation

GitHub Actions / pre-commit

Missing positional argument "moe" in call to "__init__" of "FusedMoEMethodBase" [call-arg]

Check failure on line 258 in vllm/model_executor/layers/fused_moe/layer.py

View workflow job for this annotation

GitHub Actions / pre-commit

Missing positional argument "moe" in call to "__init__" of "FusedMoEMethodBase" [call-arg]
self.fused_experts = fused_experts # type: ignore
self.topk_indices_dtype = None
self.moe = moe
Expand Down Expand Up @@ -498,7 +502,7 @@
):
if enable_eplb is not False or expert_load_view is not None or \
logical_to_physical_map is not None or \
logical_replica_count is not None:

Check failure on line 505 in vllm/model_executor/layers/fused_moe/layer.py

View workflow job for this annotation

GitHub Actions / pre-commit

"None" not callable [misc]

Check failure on line 505 in vllm/model_executor/layers/fused_moe/layer.py

View workflow job for this annotation

GitHub Actions / pre-commit

"None" not callable [misc]
raise NotImplementedError("Expert load balancing is not supported "
"for CPU.")
return layer.cpu_fused_moe(
Expand Down
193 changes: 193 additions & 0 deletions vllm/model_executor/layers/fused_moe/trtllm_moe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional

import torch

import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP)
from vllm.model_executor.layers.fused_moe.utils import extract_required_args
from vllm.utils import next_power_of_2

if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
from flashinfer import trtllm_fp4_block_scale_routed_moe


class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
    """Fused-MoE expert implementation backed by flashinfer's TRT-LLM-gen
    ``trtllm_fp4_block_scale_routed_moe`` kernel.

    The flashinfer kernel runs the permute -> grouped-GEMM -> activation ->
    grouped-GEMM -> finalize pipeline in a single call (``do_finalize``
    is passed as ``True``), so host-side finalization is a no-op and the
    workspaces returned by :meth:`workspace_shapes` are placeholders.
    """

    def __init__(self, moe: FusedMoEConfig):
        # The base class only consumes the quant config; keep the full MoE
        # config around for the EP rank / expert-offset math in apply().
        super().__init__(moe.quant_config)
        self.moe = moe

    @property
    def activation_formats(
        self
    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
        # Standard (non-batched) activation layout on both input and output.
        return (mk.FusedMoEActivationFormat.Standard,
                mk.FusedMoEActivationFormat.Standard)

    def supports_chunking(self) -> bool:
        return True

    def supports_expert_map(self) -> bool:
        # Expert placement is expressed through local_expert_offset /
        # local_num_experts in apply(), not via an explicit expert map.
        return False

    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        # The kernel finalizes internally, so weight-and-reduce is a no-op.
        return TopKWeightAndReduceNoOP()

    def workspace_shapes(
        self,
        a: torch.Tensor,
        aq: torch.Tensor,
        M: int,
        N: int,
        K: int,
        topk: int,
        global_num_experts: int,
        local_num_experts: int,
        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]:
        # The workspaces for this implementation are managed by flashinfer.
        # TODO(varun) : workspace1 is could be used as the output tensor. This
        # is error-prone. Allow the `workspace_shapes` to return None workspaces
        workspace1 = (M, K)
        workspace2 = (1, 1)  # (1, 1) as we cant return None.
        output = (M, K)
        return (workspace1, workspace2, output, a.dtype)

    def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int,
                             local_num_experts: int) -> int:
        """Pick the per-CTA tile size (in tokens) for the kernel launch,
        based on an estimated per-expert token count."""
        # Number of tokens in the input tensor.
        num_tokens = x.shape[0]
        # Factor to account for the imbalance of the experts.
        # factor equals to the
        # max_real_num_tokens_per_expert / perfect_num_tokens_per_expert
        # 1.0 means perfect expert distribution.
        # > 1.0 means some experts have more tokens than the perfect
        # distribution.
        # < 1.0 does not make sense.
        imbalance_factor = 1.3
        # Calculate the number of tokens per expert assuming perfect
        # distribution.
        num_tokens_per_expert = (num_tokens * top_k) // local_num_experts
        # Apply the imbalance factor.
        num_tokens_per_expert = int(num_tokens_per_expert * imbalance_factor)
        # And pad the number to the next power of 2.
        tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
        # Cap to 8-64 tokens per CTA tile as it's the range supported by the
        # kernel.
        tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)

        return tile_tokens_dim

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
        expert_map: Optional[torch.Tensor],
        w1_scale: Optional[torch.Tensor],
        w2_scale: Optional[torch.Tensor],
        w1_zp: Optional[torch.Tensor],
        w2_zp: Optional[torch.Tensor],
        a1q_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
        apply_router_weight_on_input: bool,
        extra_expert_args: Optional[dict[str, Any]],
    ):
        """Run the routed-MoE kernel, writing the result into ``output``.

        ``topk_ids``/``topk_weights`` are packed into the combined int32
        routing format the kernel expects; gemm alphas/betas/clamps and
        biases arrive via ``extra_expert_args``.
        """
        topk = topk_ids.size(-1)
        local_num_experts = w1.size(0)
        intermediate_size = w2.size(1)
        # Global index of this EP rank's first local expert.
        local_expert_offset = self.moe.ep_rank * local_num_experts

        x_quant = hidden_states
        x_scale = a1q_scale
        if x_scale is not None:
            # Reinterpret the scale bytes as fp8 and flatten to 1-D.
            x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)

        # Extract extra args
        required_keys = [
            'gemm1_alpha', 'gemm1_beta', 'gemm1_clamp_limit', "w1_bias",
            "w2_bias"
        ]
        gemm1_alpha, gemm1_beta, gemm1_clamp_limit, w1_bias, w2_bias = (
            extract_required_args(extra_expert_args, required_keys))

        # Pack each (expert_id, weight) pair into one int32: expert id in
        # the upper 16 bits, the bf16 weight's raw bits in the lower 16.
        packed_tensor = (topk_ids.to(torch.int32) << 16) | topk_weights.to(
            torch.bfloat16).view(torch.int16)

        assert w1_scale is not None
        assert w2_scale is not None
        kwargs = {
            "topk_ids":
            packed_tensor,
            "routing_bias":
            None,
            "hidden_states":
            x_quant,
            "hidden_states_scale":
            x_scale,
            "gemm1_weights":
            w1,
            "gemm1_weights_scale":
            w1_scale,
            "gemm1_bias":
            w1_bias,
            "gemm1_alpha":
            gemm1_alpha,
            "gemm1_beta":
            gemm1_beta,
            "gemm1_clamp_limit":
            gemm1_clamp_limit,
            "gemm2_weights":
            w2,
            "gemm2_weights_scale":
            w2_scale,
            "gemm2_bias":
            w2_bias,
            "output1_scale_scalar":
            None,
            "output1_scale_gate_scalar":
            None,
            "output2_scale_scalar":
            None,
            "num_experts":
            global_num_experts,
            "top_k":
            topk,
            "n_group":
            None,
            "topk_group":
            None,
            "intermediate_size":
            intermediate_size,
            "local_expert_offset":
            local_expert_offset,
            "local_num_experts":
            local_num_experts,
            "routed_scaling_factor":
            None,
            "tile_tokens_dim":
            self._get_tile_tokens_dim(x_quant, topk, local_num_experts),
            "routing_method_type":
            # NOTE(review): hardcoded routing method — presumably 1 ==
            # Renormalize in flashinfer's RoutingMethodType; confirm all
            # callers use renormalize routing (consider asserting).
            1,
            "do_finalize":
            True,
            "output":
            output,
        }

        trtllm_fp4_block_scale_routed_moe(**kwargs)
        return output
13 changes: 13 additions & 0 deletions vllm/model_executor/layers/fused_moe/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
per_token_group_quant_int8, per_token_quant_int8)
from vllm.model_executor.layers.quantization.utils.mxfp4_utils import (
quant_dequant_mxfp4)
from vllm.model_executor.layers.quantization.utils.mxfp8_utils import (
mxfp8_quantize)
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv
Expand Down Expand Up @@ -176,6 +178,15 @@ def _mxfp4_quantize(

return A, None

def _mxfp8_quantize(A: torch.Tensor,
                    A_scale: Optional[torch.Tensor],
                    per_act_token_quant: bool,
                    block_shape: Optional[list[int]] = None,
                    ) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize activation tensor ``A`` to mxfp8 via ``mxfp8_quantize``,
    returning the (quantized tensor, scales) pair.

    mxfp8 derives its own block scales, so callers must not supply a
    precomputed scale, request per-token quantization, or pass an
    explicit block shape.
    """
    assert A_scale is None
    assert not per_act_token_quant
    assert block_shape is None
    return mxfp8_quantize(A)

def moe_kernel_quantize_input(
A: torch.Tensor,
Expand All @@ -195,6 +206,8 @@ def moe_kernel_quantize_input(
is_sf_swizzled_layout=is_fp4_scale_swizzled)
elif quant_dtype == "mxfp4":
return _mxfp4_quantize(A, A_scale, per_act_token_quant, block_shape)
elif quant_dtype == "mxfp8":
return _mxfp8_quantize(A, A_scale, per_act_token_quant, block_shape)
else:
return A, A_scale

Expand Down
Loading
Loading