@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
@@ -12,6 +13,7 @@
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP,
)
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.triton_utils import tl, triton
from vllm.utils.import_utils import has_triton_kernels

@@ -88,14 +90,17 @@ def triton_kernel_moe_forward(
gating_output, topk, sm_first=not renormalize
)

output = torch.empty_like(hidden_states)

return triton_kernel_fused_experts(
None,
output,
hidden_states,
w1,
w2,
routing_data,
gather_idx,
scatter_idx,
topk=topk,
activation=activation,
quant_config=quant_config,
apply_router_weight_on_input=apply_router_weight_on_input,
@@ -113,13 +118,15 @@ def triton_kernel_fused_experts(
routing_data, # RoutingData
gather_indx, # GatherIndx
scatter_indx, # ScatterIndx
topk: int,
activation: str = "silu",
quant_config: FusedMoEQuantConfig | None = None,
swiglu_alpha: float = 1.702,
swiglu_limit: float = 7.0,
apply_router_weight_on_input: bool = False,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
intermediate_cache: torch.Tensor | None = None,
a1q_scale: torch.Tensor | None = None,
) -> torch.Tensor:
if quant_config is None:
@@ -131,22 +138,38 @@
assert quant_config.w2_bias is None or quant_config.w2_bias.dtype == torch.float32

    # Shape checks; only performed for non-mxfp4 weights
assert hidden_states.ndim == 2
assert hidden_states.shape[-1] == w1.shape[-2]
assert w2.shape[-1] == w1.shape[1]

batch_dim = 1
M, K = hidden_states.shape[-2:]
E, _, N = w1.shape

if global_num_experts == -1:
global_num_experts = E

if intermediate_cache is None:
intermediate_cache = torch.empty(
(batch_dim, M * topk, N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype,
)

# Add batch_dim to output buffer because matmul_ogs expects 3D output
intermediate_cache = _resize_cache(
intermediate_cache, (batch_dim, M * topk, N // 2)
)
output_tensor = _resize_cache(output_tensor, (batch_dim, M, K))

act = FusedActivation(
FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
(swiglu_alpha, swiglu_limit),
2,
)
gammas = routing_data.gate_scal if routing_data else None

intermediate_cache1 = matmul_ogs(
matmul_ogs(
hidden_states,
w1,
quant_config.w1_bias,
@@ -155,10 +178,11 @@
precision_config=quant_config.w1_precision,
gammas=gammas if apply_router_weight_on_input else None,
fused_activation=act,
y=intermediate_cache,
)

intermediate_cache3 = matmul_ogs(
intermediate_cache1,
matmul_ogs(
intermediate_cache.view(M * topk, N // 2),
w2,
quant_config.w2_bias,
routing_data,
@@ -167,7 +191,8 @@
gammas=None if apply_router_weight_on_input else gammas,
y=output_tensor,
)
return intermediate_cache3
output_tensor = output_tensor.view(M, K)
return output_tensor
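
The edits above make both matmul_ogs calls write into caller-provided buffers through the y= argument, with _resize_cache viewing a flat workspace as the shape each stage needs. Below is a minimal torch-only sketch of that buffer-reuse pattern; run_two_stage, w_up, and w_down are hypothetical names, and plain torch.mm stands in for matmul_ogs:

import torch

def run_two_stage(
    x: torch.Tensor,          # (M, K) hidden states
    w_up: torch.Tensor,       # (K, N) first projection
    w_down: torch.Tensor,     # (N, K) second projection
    workspace: torch.Tensor,  # flat scratch with at least M * N elements
    out: torch.Tensor,        # preallocated (M, K) output buffer
) -> torch.Tensor:
    M, K = x.shape
    N = w_up.shape[1]
    # View (not copy) the workspace as the intermediate shape, much like
    # _resize_cache does for intermediate_cache above.
    inter = workspace.flatten()[: M * N].view(M, N)
    torch.mm(x, w_up, out=inter)      # stage 1 writes into the workspace
    torch.mm(inter, w_down, out=out)  # stage 2 writes into the output buffer
    return out

Because both stages write into caller-owned memory, repeated forward passes can reuse one set of buffers instead of allocating per call, which is the point of threading output_tensor and intermediate_cache through triton_kernel_fused_experts.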


def make_routing_data(
@@ -221,6 +246,42 @@ def __init__(self, quant_config: FusedMoEQuantConfig):
def supports_expert_map(self) -> bool:
return True

def moe_problem_size(
self,
a1: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
) -> tuple[int, int, int, int, int]:
"""
Extract the MoE problem size from the given tensor arguments:
- a: The hidden states, input to the MoE layer.
- w1: The first set of expert weights.
- w2: The second set of expert weights.
- topk_ids: The topk ids.
Note: extracting the problem shape from the weight and activation
tensors is not obvious. It needs to be done this way specifically
due to subtle issues with particular kernels, e.g. the int4 kernels
divide the trailing dimension by two, so it's not "correct" to
extract N or K from the trailing dimension of w1 or w2. Similarly,
some kernels transpose the weights, so this needs to be kept in mind.
Note: This implementation covers most cases. However, if experts
require a specialized implementation, like MarlinExperts, they are free
to override this function.
"""
assert w1.dim() == 3 and w2.dim() == 3
E, _, N = w1.size()
K = a1.size(-1)

assert a1.dim() == 2
assert topk_ids.size(0) == a1.size(0), f"{topk_ids.size(0)} != {a1.size(0)}"
M = a1.size(0)

assert topk_ids.dim() == 2
topk = topk_ids.size(1)

return E, M, N, K, topk
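
To make the docstring's conventions concrete, here is a small worked example. The shapes follow the asserts earlier in this file, i.e. w1 as (E, K, N) and w2 as (E, N // 2, K); the sizes are arbitrary:

import torch

# E=4 experts, M=8 tokens, N=256 (gate and up interleaved), K=128 hidden.
a1 = torch.randn(8, 128)                # (M, K)
w1 = torch.randn(4, 128, 256)           # (E, K, N)
w2 = torch.randn(4, 128, 128)           # (E, N // 2, K)
topk_ids = torch.randint(0, 4, (8, 2))  # (M, topk)

# moe_problem_size(a1, w1, w2, topk_ids) returns
# (E, M, N, K, topk) == (4, 8, 256, 128, 2)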

def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
# Weight application and reduction happens in the fused_experts kernel.
return TopKWeightAndReduceNoOP()
@@ -263,8 +324,8 @@ def workspace_shapes(
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
        # workspaces are allocated inside the kernel
workspace1 = (M, K)
workspace2 = (0, 0)
workspace1 = (0, 0)
workspace2 = (M * topk, N // 2)
output = (M, K)
return (workspace1, workspace2, output)
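
A quick arithmetic check of the new shapes, reusing the illustrative sizes from the example above:

M, topk, N, K = 8, 2, 256, 128
workspace1 = (0, 0)              # nothing extra; stage 1 writes into workspace2
workspace2 = (M * topk, N // 2)  # (16, 128): holds the swiglu output
output = (M, K)                  # (8, 128): becomes output_tensor in the kernel

workspace2 is the buffer that apply() passes down as intermediate_cache, so it is allocated once by the modular-kernel machinery rather than inside every kernel invocation.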

@@ -297,20 +358,21 @@ def apply(
topk_ids, topk_weights, local_num_experts
)

experts_output = triton_kernel_fused_experts(
None,
topk = topk_ids.size(1)
triton_kernel_fused_experts(
output,
hidden_states,
w1,
w2,
routing_data,
gather_indx,
scatter_indx,
topk=topk,
activation=activation,
quant_config=self.quant_config,
apply_router_weight_on_input=False,
global_num_experts=local_num_experts,
expert_map=None, # applied already
intermediate_cache=workspace2,
a1q_scale=a1q_scale,
)

output.copy_(experts_output, non_blocking=True)
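
Net effect of the apply changes: triton_kernel_fused_experts now fills the caller's output buffer directly, so the trailing output.copy_(experts_output, non_blocking=True) is removed. Schematically, with argument lists abbreviated:

# Old convention: allocate inside the kernel, copy back afterwards.
#   experts_output = triton_kernel_fused_experts(None, hidden_states, ...)
#   output.copy_(experts_output, non_blocking=True)

# New convention: pass the destination and scratch in; no copy-back.
#   triton_kernel_fused_experts(
#       output, hidden_states, ..., topk=topk,
#       intermediate_cache=workspace2,
#   )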