15 changes: 15 additions & 0 deletions vllm/model_executor/layers/fused_moe/activation.py
@@ -15,6 +15,7 @@ class MoEActivation(Enum):
# and produce output of shape [..., d]
SILU = "silu"
GELU = "gelu"
GELU_TANH = "gelu_tanh"
RELU2 = "relu2"
SWIGLUOAI = "swigluoai"
SWIGLUSTEP = "swiglustep"
@@ -24,6 +25,7 @@ class MoEActivation(Enum):
# NOTE: Non-gated activations require the "_no_mul" suffix to be present.
SILU_NO_MUL = "silu_no_mul"
GELU_NO_MUL = "gelu_no_mul"
GELU_TANH_NO_MUL = "gelu_tanh_no_mul"
RELU2_NO_MUL = "relu2_no_mul"

@property
@@ -53,6 +55,8 @@ def without_mul(self) -> "MoEActivation":
@classmethod
def from_str(cls, s: str) -> "MoEActivation":
"""Parse from string for backward compatibility."""
if s == "gelu_pytorch_tanh":
s = cls.GELU_TANH.value
for member in cls:
if member.value == s:
return member
@@ -64,17 +68,20 @@ def from_str(cls, s: str) -> "MoEActivation":
_CUSTOM_OP_NAMES: dict[MoEActivation, str] = {
MoEActivation.SILU: "silu_and_mul",
MoEActivation.GELU: "gelu_and_mul",
MoEActivation.GELU_TANH: "gelu_tanh_and_mul",
MoEActivation.SWIGLUOAI: "swigluoai_and_mul",
MoEActivation.SWIGLUSTEP: "swiglustep_and_mul",
MoEActivation.RELU2: "relu2",
MoEActivation.SILU_NO_MUL: "silu_and_mul",
MoEActivation.GELU_NO_MUL: "gelu_and_mul",
MoEActivation.GELU_TANH_NO_MUL: "gelu_tanh_and_mul",
MoEActivation.RELU2_NO_MUL: "relu2",
}

_WITHOUT_MUL: dict[MoEActivation, MoEActivation] = {
MoEActivation.SILU: MoEActivation.SILU_NO_MUL,
MoEActivation.GELU: MoEActivation.GELU_NO_MUL,
MoEActivation.GELU_TANH: MoEActivation.GELU_TANH_NO_MUL,
MoEActivation.RELU2: MoEActivation.RELU2_NO_MUL,
}

@@ -115,6 +122,12 @@ def apply_moe_activation(
torch.ops._C.silu_and_mul(output, input)
elif activation == MoEActivation.GELU:
torch.ops._C.gelu_and_mul(output, input)
elif activation == MoEActivation.GELU_TANH:
if hasattr(torch.ops._C, "gelu_tanh_and_mul"):
torch.ops._C.gelu_tanh_and_mul(output, input)
else:
gate, up = input.chunk(2, dim=-1)
output.copy_(F.gelu(gate, approximate="tanh") * up)
elif activation == MoEActivation.SWIGLUOAI:
torch.ops._C.swigluoai_and_mul(output, input)
elif activation == MoEActivation.SWIGLUSTEP:
@@ -127,6 +140,8 @@
output.copy_(F.silu(input))
elif activation == MoEActivation.GELU_NO_MUL:
output.copy_(F.gelu(input))
elif activation == MoEActivation.GELU_TANH_NO_MUL:
output.copy_(F.gelu(input, approximate="tanh"))
elif activation == MoEActivation.RELU2_NO_MUL:
F.relu(input, inplace=True)
torch.square(input, out=output)
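A minimal standalone sketch of the gated gelu_tanh activation added in this file, using only public PyTorch calls; the shapes are illustrative, and the gate/up split mirrors the eager fallback branch of apply_moe_activation above.

import torch
import torch.nn.functional as F

d = 8
x = torch.randn(4, 2 * d)                        # fused-MoE activation input packs [gate, up] on the last dim
gate, up = x.chunk(2, dim=-1)                    # split into gate and up halves
out = F.gelu(gate, approximate="tanh") * up      # gated gelu_tanh, as in the fallback path
print(out.shape)                                 # torch.Size([4, 8])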
4 changes: 4 additions & 0 deletions vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
@@ -48,6 +48,10 @@ def _gelu_and_mul(
MoEActivation.SILU: SiluAndMul.forward_native,
MoEActivation.SWIGLUOAI: _swigluoai_forward_native,
MoEActivation.GELU: _gelu_and_mul,
MoEActivation.GELU_TANH: (
lambda x: F.gelu(x[..., : x.shape[-1] // 2], approximate="tanh")
* x[..., x.shape[-1] // 2 :]
),
}


3 changes: 3 additions & 0 deletions vllm/model_executor/layers/fused_moe/cutlass_moe.py
@@ -319,6 +319,7 @@ def _supports_activation(activation: MoEActivation) -> bool:
return activation in [
MoEActivation.SILU,
MoEActivation.GELU,
MoEActivation.GELU_TANH,
MoEActivation.SWIGLUOAI,
]

@@ -709,10 +710,12 @@ def _supports_activation(activation: MoEActivation) -> bool:
return activation in [
MoEActivation.SILU,
MoEActivation.GELU,
MoEActivation.GELU_TANH,
MoEActivation.SWIGLUOAI,
MoEActivation.SWIGLUSTEP,
MoEActivation.SILU_NO_MUL,
MoEActivation.GELU_NO_MUL,
MoEActivation.GELU_TANH_NO_MUL,
MoEActivation.RELU2_NO_MUL,
]

2 changes: 2 additions & 0 deletions vllm/model_executor/layers/fused_moe/fused_batched_moe.py
@@ -944,9 +944,11 @@ def _supports_activation(activation: MoEActivation) -> bool:
return activation in [
MoEActivation.SILU,
MoEActivation.GELU,
MoEActivation.GELU_TANH,
MoEActivation.SWIGLUOAI,
MoEActivation.SILU_NO_MUL,
MoEActivation.GELU_NO_MUL,
MoEActivation.GELU_TANH_NO_MUL,
MoEActivation.RELU2_NO_MUL,
]

2 changes: 2 additions & 0 deletions vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -599,10 +599,12 @@ def _supports_activation(activation: MoEActivation) -> bool:
return activation in [
MoEActivation.SILU,
MoEActivation.GELU,
MoEActivation.GELU_TANH,
MoEActivation.SWIGLUOAI,
MoEActivation.SWIGLUSTEP,
MoEActivation.SILU_NO_MUL,
MoEActivation.GELU_NO_MUL,
MoEActivation.GELU_TANH_NO_MUL,
MoEActivation.RELU2_NO_MUL,
]

2 changes: 2 additions & 0 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1940,10 +1940,12 @@ def _supports_activation(activation: MoEActivation) -> bool:
return activation in [
MoEActivation.SILU,
MoEActivation.GELU,
MoEActivation.GELU_TANH,
MoEActivation.SWIGLUOAI,
MoEActivation.SWIGLUSTEP,
MoEActivation.SILU_NO_MUL,
MoEActivation.GELU_NO_MUL,
MoEActivation.GELU_TANH_NO_MUL,
MoEActivation.RELU2_NO_MUL,
]

157 changes: 157 additions & 0 deletions vllm/model_executor/layers/quantization/inc.py
@@ -6,8 +6,10 @@

import regex as re
import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from vllm.distributed import get_tensor_model_parallel_rank
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (
LinearBase,
@@ -18,9 +20,17 @@
QuantizationConfig,
QuantizationMethods,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
check_marlin_supports_layer,
)
Comment on lines +23 to +25

Contributor:

high

The get_tensor_model_parallel_rank function is required to correctly calculate the shard_offset in INCGPTQRowParallelTailLinearMethod.create_weights. It should be imported from vllm.distributed.

Suggested change:
-from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-    check_marlin_supports_layer,
-)
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+)
+from vllm.model_executor.layers.quantization.utils.marlin_utils import (
+    check_marlin_supports_layer,
+)

Contributor Author:

ok fixed
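
A toy sketch (hypothetical sizes, not taken from this PR) of why the tensor-parallel rank is needed for shard_offset: in a row-parallel linear, rank r owns input columns [r * shard_width, (r + 1) * shard_width), so each column's GPTQ group index must be offset by that start before dividing by group_size.

import torch

# Hypothetical sizes chosen for illustration only.
input_size = 320
group_size = 128
tp_size = 2
shard_width = input_size // tp_size                  # 160 input columns per row-parallel shard
assert shard_width % group_size != 0                 # 160 % 128 == 32: shard straddles a group boundary

for rank in range(tp_size):
    shard_offset = rank * shard_width                # this is where the TP rank comes in
    g_idx = (torch.arange(shard_width, dtype=torch.int32) + shard_offset) // group_size
    print(rank, sorted(g_idx.unique().tolist()))     # rank 0 -> [0, 1], rank 1 -> [1, 2]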

from vllm.model_executor.layers.quantization.utils.quant_utils import (
unpack_quantized_values_into_int32,
)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedColumnParameter,
PackedvLLMParameter,
RowvLLMParameter,
)
@@ -341,6 +351,22 @@ def apply_gptq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
group_size,
sym,
)
if (
isinstance(layer, LinearBase)
and group_size > 0
and getattr(layer, "input_size_per_partition", layer.input_size)
% group_size
!= 0
):
# Gemma4 AutoRound row-parallel linears can produce TP shards that
# straddle a GPTQ group boundary. Fall back to a correctness-first
# path in that case instead of using Marlin/GPTQ kernels that
# assume group-aligned input shards.
return INCGPTQRowParallelTailLinearMethod(
weight_bits=weight_bits,
group_size=group_size,
sym=sym,
)
if backend == "auto" or "marlin" in backend:
GPTQ_TYPE_MAP = {
(4, True): scalar_types.uint4b8,
@@ -353,6 +379,10 @@ def apply_gptq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
use_marlin = use_marlin and check_moe_marlin_supports_layer(
layer, group_size
)
elif isinstance(layer, LinearBase):
use_marlin = use_marlin and check_marlin_supports_layer(
layer, group_size
)
else:
use_marlin = False
if use_marlin:
@@ -625,3 +655,130 @@ def apply(
None, # g_idx not needed: desc_act is always False for INC models
)
return out.reshape(out_shape)


class INCGPTQRowParallelTailLinearMethod(LinearMethodBase):
"""Fallback for row-parallel GPTQ-family linears with group-tail shards."""

def __init__(self, weight_bits: int, group_size: int, sym: bool):
self.weight_bits = weight_bits
self.group_size = group_size
self.sym = sym
self.pack_factor = 32 // weight_bits
self.weight_type = {
4: scalar_types.uint4b8,
8: scalar_types.uint8b128,
}[weight_bits]

def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
):
del output_size # Unused.
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
full_num_groups = (input_size + self.group_size - 1) // self.group_size

qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition // self.pack_factor,
output_size_per_partition,
dtype=torch.int32,
),
input_dim=0,
output_dim=1,
packed_dim=0,
packed_factor=self.pack_factor,
weight_loader=weight_loader,
)
scales = ChannelQuantScaleParameter(
data=torch.empty(
full_num_groups,
output_size_per_partition,
dtype=params_dtype,
),
output_dim=1,
weight_loader=weight_loader,
)
qzeros = PackedColumnParameter(
data=torch.empty(
full_num_groups,
output_size_per_partition // self.pack_factor,
dtype=torch.int32,
),
output_dim=1,
packed_dim=1,
packed_factor=self.pack_factor,
weight_loader=weight_loader,
)

layer.register_parameter("qweight", qweight)
layer.register_parameter("scales", scales)
layer.register_parameter("qzeros", qzeros)

shard_width = getattr(
layer, "input_size_per_partition", input_size_per_partition
)
shard_offset = get_tensor_model_parallel_rank() * shard_width
g_idx = (
torch.arange(input_size_per_partition, dtype=torch.int32) + shard_offset
) // self.group_size
layer.register_parameter("g_idx", Parameter(g_idx, requires_grad=False))
layer._inc_tail_dequant_weight = None

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if self.sym:
# The tail-shard fallback dequantizes weights on demand and handles
# the symmetric zero point via weight_type.bias in
# _get_dequantized_weight(), so the large packed qzeros tensor is
# replaced with a tiny placeholder after loading.
layer.qzeros = Parameter(
torch.tensor([8], dtype=torch.int8, device=layer.qweight.device),
requires_grad=False,
)
else:
layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
layer.scales = Parameter(layer.scales.data, requires_grad=False)

def _get_dequantized_weight(self, layer: torch.nn.Module) -> torch.Tensor:
cached = layer._inc_tail_dequant_weight
if cached is not None:
return cached

if not self.sym:
raise NotImplementedError(
"INCGPTQRowParallelTailLinearMethod currently supports only "
"symmetric checkpoints."
)

qweight = unpack_quantized_values_into_int32(
layer.qweight.data, self.weight_type, packed_dim=0
).to(torch.float32)
qweight = qweight - float(self.weight_type.bias)

g_idx = layer.g_idx.data.to(torch.long)
scales = layer.scales.data.to(torch.float32)
dequant = qweight * scales.index_select(0, g_idx)
weight = dequant.t().contiguous()
# Cache the dequantized tail-shard weight after the first fallback use.
layer._inc_tail_dequant_weight = weight
return weight

def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
out_shape = x.shape[:-1] + (layer.qweight.shape[1],)
x_2d = x.reshape(-1, x.shape[-1]).to(torch.float32)
bias_2d = bias.to(torch.float32) if bias is not None else None
output = F.linear(x_2d, self._get_dequantized_weight(layer), bias_2d)
return output.to(x.dtype).reshape(out_shape)
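
A compact numeric sketch of the symmetric dequantization performed by _get_dequantized_weight, on toy values rather than a real checkpoint: each unpacked 4-bit value is shifted by the scalar type's bias (8 for uint4b8) and multiplied by the GPTQ scale of the group its input column falls into.

import torch

# Toy values for illustration only (not real GPTQ checkpoint data).
group_size = 4
q = torch.tensor([0.0, 3.0, 8.0, 15.0, 2.0, 9.0, 12.0, 7.0])   # unpacked 4-bit values for one output column
scales = torch.tensor([0.5, 0.25])                              # one scale per group of 4 input columns
g_idx = (torch.arange(q.numel()) // group_size).to(torch.long)  # group index of each input column

bias = 8.0                                                      # zero point implied by uint4b8 (symmetric case)
w = (q - bias) * scales.index_select(0, g_idx)                  # same math as _get_dequantized_weight
print(w.tolist())   # [-4.0, -2.5, 0.0, 3.5, -1.5, 0.25, 1.0, -0.25]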
6 changes: 1 addition & 5 deletions vllm/model_executor/layers/quantization/moe_wna16.py
@@ -6,7 +6,6 @@
import torch

from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
int4_w4a16_moe_quant_config,
@@ -372,17 +371,14 @@ def apply(
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts

assert layer.activation == MoEActivation.SILU, (
f"Only SiLU activation is supported, not {layer.activation}."
)

return fused_experts(
x,
layer.w13_qweight,
layer.w2_qweight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=not self.moe.disable_inplace,
activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,