2 changes: 2 additions & 0 deletions vllm/model_executor/layers/fused_moe/__init__.py
@@ -83,6 +83,7 @@ def get_config() -> dict[str, Any] | None:
from vllm.model_executor.layers.fused_moe.fused_moe import (
    GroupedTopk,
    TritonExperts,
+    TritonWNA16Experts,
    fused_experts,
    fused_topk,
    get_config_file_name,
@@ -103,6 +104,7 @@ def get_config() -> dict[str, Any] | None:
    "CutlassBatchedExpertsFp8",
    "CutlassExpertsW4A8Fp8",
    "TritonExperts",
+    "TritonWNA16Experts",
    "BatchedTritonExperts",
    "DeepGemmExperts",
    "BatchedDeepGemmExperts",
146 changes: 144 additions & 2 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -617,11 +617,11 @@ def invoke_fused_moe_wna16_triton_kernel(
    compute_type: tl.dtype,
    use_int8_w8a16: bool,
    use_int4_w4a16: bool,
-    block_shape: list[int],
+    block_shape: list[int] | None,
):
    assert B_scale is not None and B_scale.ndim == 3
    assert B_zp is None or B_zp.ndim == 3
-    assert block_shape is None or block_shape[0] == 0
+    assert block_shape is not None and block_shape[0] == 0

    M = A.size(0)
    num_tokens = M * top_k
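
For readers unfamiliar with the WNA16 convention: `block_shape` here does not describe activation block quantization; its first element is fixed at 0 and the second carries the weight-quantization group size, which is why the assertion now requires a non-None value. A minimal sketch of that convention (the group size of 128 is a hypothetical example, not from this PR):

# Sketch, assuming the [0, group_size] layout the WNA16 kernels expect;
# group_size = 128 is an illustrative value only.
block_shape = [0, 128]
assert block_shape is not None and block_shape[0] == 0
group_size = block_shape[1]  # quantization group size along the K dimension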
@@ -2440,6 +2440,148 @@ def moe_sum(self, input: torch.Tensor, output: torch.Tensor) -> None:
        ops.moe_sum(input, output)


class TritonWNA16Experts(TritonExperts):
    def __init__(
        self,
        quant_config: FusedMoEQuantConfig,
    ):
        super().__init__(quant_config)

    def apply(
        self,
        output: torch.Tensor,
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        activation: str,
        global_num_experts: int,
        expert_map: torch.Tensor | None,
        a1q_scale: torch.Tensor | None,
        a2_scale: torch.Tensor | None,
        workspace13: torch.Tensor,
        workspace2: torch.Tensor,
        expert_tokens_meta: mk.ExpertTokensMetadata | None,
        apply_router_weight_on_input: bool,
    ):
        # Check constraints.
        if self.quant_config.use_int4_w4a16:
            assert hidden_states.size(-1) // 2 == w1.size(2), "Hidden size mismatch"
        else:
            assert hidden_states.size(-1) == w1.size(2), (
                f"Hidden size mismatch {hidden_states.size(-1)} != {w1.size(2)}"
            )

        assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
        assert hidden_states.dim() == 2
        assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
        assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
        assert hidden_states.dtype in [
            torch.float32,
            torch.float16,
            torch.bfloat16,
            torch.float8_e4m3fn,
            torch.float8_e4m3fnuz,
        ]

        E, num_tokens, N, K, top_k_num = self.moe_problem_size(
            hidden_states, w1, w2, topk_ids
        )

        if global_num_experts == -1:
            global_num_experts = E

        config = try_get_optimal_moe_config(
            w1.size(),
            w2.size(),
            top_k_num,
            self.quant_config.config_name(hidden_states.dtype),
            num_tokens,
            block_shape=self.block_shape,
        )

        if hidden_states.dtype == torch.bfloat16:
            compute_type = tl.bfloat16
        elif hidden_states.dtype == torch.float16:
            compute_type = tl.float16
        elif hidden_states.dtype == torch.float32:
            compute_type = tl.float32
        elif (
            hidden_states.dtype == torch.float8_e4m3fn
            or hidden_states.dtype == torch.float8_e4m3fnuz
        ):
            compute_type = tl.bfloat16
        else:
            raise ValueError(f"Unsupported compute_type: {hidden_states.dtype}")

        # Note that the output tensor might be in workspace1
        intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N))
        intermediate_cache2 = _resize_cache(
            workspace13, (num_tokens * top_k_num, N // 2)
        )
        intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K))

        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
            topk_ids, config["BLOCK_SIZE_M"], global_num_experts, expert_map
        )

        invoke_fused_moe_wna16_triton_kernel(
            hidden_states,
            w1,
            intermediate_cache1,
            self.w1_scale,
            self.quant_config.w1_zp,
            None,  # topk_weights
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            False,  # mul_routed_weights
            top_k_num,
            config,
            compute_type=compute_type,
            use_int8_w8a16=self.quant_config.use_int8_w8a16,
            use_int4_w4a16=self.quant_config.use_int4_w4a16,
            block_shape=self.block_shape,
        )

        self.activation(
            activation, intermediate_cache2, intermediate_cache1.view(-1, N)
        )

        a2q_scale: torch.Tensor | None = None

        qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
            intermediate_cache2,
            a2_scale,
            self.quant_dtype,
            self.per_act_token_quant,
            self.block_shape,
        )

        invoke_fused_moe_wna16_triton_kernel(
            qintermediate_cache2,
            w2,
            intermediate_cache3,
            self.w2_scale,
            self.quant_config.w2_zp,
            topk_weights,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            not apply_router_weight_on_input,  # mul_routed_weights
            1,  # top_k: rows are already expanded per (token, expert) pair
            config,
            compute_type=compute_type,
            use_int8_w8a16=self.quant_config.use_int8_w8a16,
            use_int4_w4a16=self.quant_config.use_int4_w4a16,
            block_shape=self.block_shape,
        )

        # separate function is required for MoE + LoRA
        self.moe_sum(intermediate_cache3, output)

def modular_triton_fused_moe(
    quant_config: FusedMoEQuantConfig, shared_experts: torch.nn.Module | None = None
) -> mk.FusedMoEModularKernel:
@@ -2006,11 +2006,11 @@ def select_gemm_impl(
        from vllm.triton_utils import HAS_TRITON

        if HAS_TRITON:
-            from vllm.model_executor.layers.fused_moe import TritonExperts
+            from vllm.model_executor.layers.fused_moe import TritonWNA16Experts

            layer.w13_weight = layer.w13_weight_packed
            layer.w2_weight = layer.w2_weight_packed
-            return TritonExperts(quant_config=self.moe_quant_config)
+            return TritonWNA16Experts(quant_config=self.moe_quant_config)
        else:
            raise NotImplementedError(
                "TritonExperts requires Triton. "
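
In effect, the WNA16 compressed-tensors path now dispatches to the specialized experts class whenever Triton is available. A compact sketch of the selection logic, with `pick_experts` a hypothetical stand-in for the method above:

def pick_experts(moe_quant_config, has_triton: bool):
    # Hypothetical distillation of select_gemm_impl above, not a new API.
    if not has_triton:
        raise NotImplementedError("TritonExperts requires Triton.")
    # int4/int8 weight-only MoE now gets the WNA16-specialized kernels.
    return TritonWNA16Experts(quant_config=moe_quant_config)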