Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 0 additions & 129 deletions tests/kernels/moe/test_triton_moe_no_act_mul.py

This file was deleted.

9 changes: 0 additions & 9 deletions vllm/model_executor/layers/fused_moe/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,11 +201,6 @@ class FusedMoEQuantConfig:
_w1: FusedMoEQuantDesc
_w2: FusedMoEQuantDesc

# Whether activation is fused with gate multiplication (SwiGLU-style).
# When True: intermediate_size = N // 2 (gate and up are combined)
# When False: intermediate_size = N (no gate multiplication)
is_act_and_mul: bool = True

def __post_init__(self):
assert not self.per_act_token_quant or self.block_shape is None, (
"illegal quantization"
Expand Down Expand Up @@ -444,7 +439,6 @@ def make(
w1_zp: torch.Tensor | None = None,
w2_zp: torch.Tensor | None = None,
weight_dtype: torch.dtype | str | None = None,
is_act_and_mul: bool = True,
) -> "FusedMoEQuantConfig":
"""
General builder function for a FusedMoEQuantConfig.
Expand Down Expand Up @@ -504,7 +498,6 @@ def make(
_w2=FusedMoEQuantDesc(
weight_dtype, w_shape, w2_scale, g2_alphas, w2_zp, w2_bias
),
is_act_and_mul=is_act_and_mul,
)
assert quant_config.per_act_token_quant == per_act_token_quant
assert quant_config.per_out_ch_quant == per_out_ch_quant
Expand Down Expand Up @@ -836,7 +829,6 @@ def awq_marlin_moe_quant_config(
def biased_moe_quant_config(
w1_bias: torch.Tensor | None,
w2_bias: torch.Tensor | None,
is_act_and_mul: bool = True,
) -> FusedMoEQuantConfig:
"""
Construct a quant config for unquantized activations with biases.
Expand All @@ -846,7 +838,6 @@ def biased_moe_quant_config(
_a2=FusedMoEQuantDesc(),
_w1=FusedMoEQuantDesc(bias=w1_bias),
_w2=FusedMoEQuantDesc(bias=w2_bias),
is_act_and_mul=is_act_and_mul,
)


Expand Down
13 changes: 3 additions & 10 deletions vllm/model_executor/layers/fused_moe/fused_batched_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,11 +871,8 @@ def workspace_shapes(
num_dp = self.num_dispatchers
num_experts = local_num_experts
max_num_tokens = self.max_num_tokens
# For fused activations (SwiGLU): N = 2 * intermediate, after act = N/2
# For non-fused activations: N = intermediate, after act = N
intermediate_size = N // 2 if self.quant_config.is_act_and_mul else N
workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N))
workspace2 = (num_experts, max_num_tokens * num_dp, intermediate_size)
workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2))
output = (num_experts, max_num_tokens * num_dp, K)
return (workspace13, workspace2, output)

Expand Down Expand Up @@ -950,11 +947,7 @@ def apply(
# We can reuse the memory between these because by the time we need
# cache3, we're done with cache1
intermediate_cache1 = _resize_cache(workspace13, (E, max_num_tokens, N))
# For fused activations (SwiGLU): output is N/2, for non-fused: output is N
intermediate_size = N // 2 if self.quant_config.is_act_and_mul else N
intermediate_cache2 = _resize_cache(
workspace2, (E, max_num_tokens, intermediate_size)
)
intermediate_cache2 = _resize_cache(workspace2, (E, max_num_tokens, N // 2))

# TODO(bnell): should this be done for any quantized type?
if self.quant_config.use_fp8_w8a8:
Expand Down Expand Up @@ -985,7 +978,7 @@ def apply(
# TODO (bnell): use triton utility from batched deep gemm.
self.activation(
activation,
intermediate_cache2.view(-1, intermediate_size),
intermediate_cache2.view(-1, N // 2),
intermediate_cache1.view(-1, N),
)

Expand Down
9 changes: 2 additions & 7 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2296,10 +2296,7 @@ def workspace_shapes(
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# For fused activations (SwiGLU): N = 2 * intermediate, after act = N/2
# For non-fused activations: N = intermediate, after act = N
intermediate_size = N // 2 if self.quant_config.is_act_and_mul else N
workspace1 = (M, topk, max(intermediate_size, K))
workspace1 = (M, topk, max(N // 2, K))
workspace2 = (M, topk, max(N, K))
output = (M, K)
return (workspace1, workspace2, output)
Expand Down Expand Up @@ -2374,10 +2371,8 @@ def apply(

# Note that the output tensor might be in workspace1
intermediate_cache1 = _resize_cache(workspace2, (num_tokens, top_k_num, N))
# For fused activations (SwiGLU): output is N/2, for non-fused: output is N
intermediate_size = N // 2 if self.quant_config.is_act_and_mul else N
intermediate_cache2 = _resize_cache(
workspace13, (num_tokens * top_k_num, intermediate_size)
workspace13, (num_tokens * top_k_num, N // 2)
)
intermediate_cache3 = _resize_cache(workspace2, (num_tokens, top_k_num, K))

Expand Down
10 changes: 2 additions & 8 deletions vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,15 +600,9 @@ def _get_quant_method() -> FusedMoEMethodBase:
"is_act_and_mul=False is supported only for unquantized "
", ModelOpt FP8, and ModelOpt NvFp4 checkpoints"
)
# ROCm without AITER MoE uses Triton which supports
# is_act_and_mul=False via standard PyTorch ops (F.silu, F.gelu)
rocm_without_aiter_moe = (
current_platform.is_rocm() and not rocm_aiter_ops.is_fused_moe_enabled()
)
if not current_platform.is_cuda() and not rocm_without_aiter_moe:
if not current_platform.is_cuda():
raise NotImplementedError(
"is_act_and_mul=False is supported only for CUDA, or ROCm "
"(when AITER MoE is disabled) for now"
"is_act_and_mul=False is supported only for CUDA for now"
)
Comment on lines +603 to 606
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The surrounding `if not self.moe_config.is_act_and_mul:` block (starting at line 584) is now dead code because this PR reverts the `is_act_and_mul=False` feature. With this revert, `self.moe_config.is_act_and_mul` will always be `True`. To complete the revert and improve code clarity, this entire `if` block (lines 584-607) should be removed. As a follow-up, the `is_act_and_mul` attribute should also be removed from `FusedMoEConfig` in `vllm/model_executor/layers/fused_moe/config.py`.


if self.enable_eplb and not self.quant_method.supports_eplb:
Expand Down
25 changes: 2 additions & 23 deletions vllm/model_executor/layers/fused_moe/modular_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections.abc import Callable
from dataclasses import dataclass
from enum import Enum
from math import prod, sqrt
from math import prod
from typing import final

import torch
Expand Down Expand Up @@ -575,35 +575,14 @@ def workspace_shapes(
def activation(
self, activation: str, output: torch.Tensor, input: torch.Tensor
) -> None:
# Fused activations (SwiGLU-style): output is half the size of input
assert output.size(-1) * 2 == input.size(-1)
if activation == "silu":
assert output.size(-1) * 2 == input.size(-1)
torch.ops._C.silu_and_mul(output, input)
elif activation == "gelu":
assert output.size(-1) * 2 == input.size(-1)
torch.ops._C.gelu_and_mul(output, input)
elif activation == "swigluoai":
# alpha = 1.702, limit = 7.0
assert output.size(-1) * 2 == input.size(-1)
torch.ops._C.swigluoai_and_mul(output, input)
# Non-fused activations (is_act_and_mul=False): output same size as input
elif activation == "silu_no_mul":
assert output.size(-1) == input.size(-1)
# Use out= argument to avoid intermediate tensor
torch.sigmoid(input, out=output)
output.mul_(input)
elif activation == "gelu_no_mul":
assert output.size(-1) == input.size(-1)
# GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
# Use out= and in-place ops to avoid intermediate tensors
output.copy_(input).div_(sqrt(2))
torch.erf(output, out=output)
output.add_(1).mul_(input).mul_(0.5)
elif activation == "relu2_no_mul":
assert output.size(-1) == input.size(-1)
# ReLU²: clamp has out=, then in-place square
torch.clamp(input, min=0, out=output)
output.square_()
else:
raise ValueError(f"Unsupported FusedMoe activation: {activation}")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -299,12 +299,7 @@ def get_fused_moe_quant_config(self, layer: torch.nn.Module) -> FusedMoEQuantCon
return biased_moe_quant_config(
layer.w13_bias,
layer.w2_bias,
is_act_and_mul=self.moe.is_act_and_mul,
)
elif not self.moe.is_act_and_mul:
# Create a config with is_act_and_mul=False since
# FUSED_MOE_UNQUANTIZED_CONFIG has is_act_and_mul=True
return FusedMoEQuantConfig.make(is_act_and_mul=False)
else:
return FUSED_MOE_UNQUANTIZED_CONFIG

Expand Down