1 change: 0 additions & 1 deletion tools/pre_commit/mypy.py
@@ -29,7 +29,6 @@
"tests",
# v0 related
"vllm/lora",
"vllm/model_executor/layers",
]

# TODO(woosuk): Include the code from Megatron and HuggingFace.
27 changes: 15 additions & 12 deletions vllm/model_executor/layers/activation.py
@@ -666,16 +666,7 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
"gelu": lambda: GELU(),
"gelu_fast": lambda: FastGELU(),
"gelu_new": lambda: NewGELU(),
"gelu_pytorch_tanh": lambda: (
# TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile
logger.warning_once(
"[ROCm] PyTorch's native GELU with tanh approximation is unstable. "
"Falling back to GELU(approximate='none')."
),
nn.GELU(approximate="none"),
)[1]
if current_platform.is_rocm()
else nn.GELU(approximate="tanh"),
"gelu_pytorch_tanh": lambda: _get_gelu_pytorch_tanh(),
"relu": lambda: nn.ReLU(),
"relu2": lambda: ReLUSquaredActivation(),
"silu": lambda: nn.SiLU(),
@@ -687,6 +678,18 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
)


def _get_gelu_pytorch_tanh() -> nn.Module:
"""Get PyTorch GELU with tanh approximation, with ROCm fallback."""
if current_platform.is_rocm():
# TODO:[ROCm] PyTorch native GELU with tanh is unstable with torch.compile
logger.warning_once(
"[ROCm] PyTorch's native GELU with tanh approximation is unstable. "
"Falling back to GELU(approximate='none')."
)
return nn.GELU(approximate="none")
return nn.GELU(approximate="tanh")


def get_act_fn(act_fn_name: str) -> nn.Module:
"""Get an activation function by name."""
act_fn_name = act_fn_name.lower()
@@ -703,12 +706,12 @@ def get_act_fn(act_fn_name: str) -> nn.Module:
return _ACTIVATION_REGISTRY[act_fn_name]


_ACTIVATION_AND_MUL_REGISTRY = LazyDict(
_ACTIVATION_AND_MUL_REGISTRY: LazyDict[nn.Module] = LazyDict(
{
"gelu": lambda: GeluAndMul(),
"silu": lambda: SiluAndMul(),
"geglu": lambda: GeluAndMul(),
"swigluoai": lambda *args, **kwargs: SwigluOAIAndMul(*args, **kwargs),
"swigluoai": lambda: SwigluOAIAndMul(),
}
)

21 changes: 15 additions & 6 deletions vllm/model_executor/layers/attention/attention.py
@@ -33,6 +33,7 @@
)
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionMetadata,
AttentionType,
)
from vllm.v1.attention.backends.registry import AttentionBackendEnum
@@ -209,6 +210,7 @@ def __init__(
`self.kv_cache`.
"""
super().__init__()
sliding_window: int | None
if per_layer_sliding_window is not None:
# per-layer sliding window
sliding_window = per_layer_sliding_window
@@ -335,7 +337,7 @@ def __init__(
cache_config.enable_prefix_caching = False

impl_cls = self.attn_backend.get_impl_cls()
self.impl = impl_cls(
self.impl = impl_cls( # type: ignore[assignment] # impl_cls always returns an AttentionImpl subclass
num_heads,
head_size,
scale,
@@ -576,7 +578,7 @@ def process_weights_after_loading(self, act_dtype: torch.dtype):
def get_attn_backend(self) -> type[AttentionBackend]:
return self.attn_backend

def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
# Block size may get updated after model loading, refresh it
block_size = vllm_config.cache_config.block_size
# Should not be called for enc-dec or encoder-only attention.
@@ -680,9 +682,16 @@ def get_attention_context(
extracted from the forward context.
"""
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[layer_name]
attn_metadata_raw = forward_context.attn_metadata
attn_metadata: AttentionMetadata
if isinstance(attn_metadata_raw, dict):
attn_metadata = attn_metadata_raw[layer_name]
elif isinstance(attn_metadata_raw, list):
# list[dict[str, AttentionMetadata]]: used in speculative decoding
# where [0] is the base-model (non-speculative) metadata dict.
attn_metadata = attn_metadata_raw[0][layer_name]
else:
attn_metadata = attn_metadata_raw
attn_layer: Attention | MLAAttention = forward_context.no_compile_layers[layer_name]
kv_cache = attn_layer.kv_cache
slot_mapping = forward_context.slot_mapping
@@ -708,7 +717,7 @@ def unified_kv_cache_update(
assert hasattr(attn_layer.impl, "do_kv_cache_update"), (
f"{attn_layer.impl.__class__.__name__} does not support kv cache update"
)
attn_layer.impl.do_kv_cache_update(
attn_layer.impl.do_kv_cache_update( # type: ignore[attr-defined]
attn_layer,
key,
value,
@@ -29,7 +29,7 @@

@functools.lru_cache
def create_chunked_local_attention_backend(
underlying_attn_backend: AttentionBackend,
underlying_attn_backend: type[AttentionBackend],
attention_chunk_size: int,
) -> type[AttentionBackend]:
prefix = f"ChunkedLocalAttention_{attention_chunk_size}_"
13 changes: 9 additions & 4 deletions vllm/model_executor/layers/attention/cross_attention.py
@@ -72,7 +72,7 @@ def _get_cross_slot_mapping(

@functools.lru_cache
def create_cross_attention_backend(
underlying_attn_backend: AttentionBackend,
underlying_attn_backend: type[AttentionBackend],
) -> type[AttentionBackend]:
prefix = "CrossAttention_"
underlying_builder = underlying_attn_backend.get_builder_cls()
@@ -87,6 +87,7 @@ def build(
) -> AttentionMetadata:
new_metadata = copy(common_attn_metadata)
new_metadata.causal = False
assert new_metadata.encoder_seq_lens_cpu is not None
max_encoder_len = int(new_metadata.encoder_seq_lens_cpu.max())
new_metadata.max_seq_len = max_encoder_len
# Any computed tokens indicate decode step > 1 (no chunked prefill)
@@ -118,7 +119,7 @@ def build(
self.device,
)
attn_metadata = super().build(common_prefix_len, new_metadata, fast_build)
attn_metadata.slot_mapping = slot_mapping
attn_metadata.slot_mapping = slot_mapping # type: ignore[attr-defined]
return attn_metadata

# NOTE(Lucas): we need a custom impl so we can use the slot-mapping computed by
@@ -144,8 +145,12 @@ def forward(
and key is not None
and value is not None
):
self.do_kv_cache_update(
layer, key, value, kv_cache, attn_metadata.slot_mapping
self.do_kv_cache_update( # type: ignore[attr-defined]
layer,
key,
value,
kv_cache,
attn_metadata.slot_mapping, # type: ignore[attr-defined]
)

return super().forward(
@@ -21,7 +21,7 @@

@functools.lru_cache
def create_encoder_only_attention_backend(
underlying_attn_backend: AttentionBackend,
underlying_attn_backend: type[AttentionBackend],
) -> type[AttentionBackend]:
prefix = "EncoderOnlyAttention_"
underlying_builder = underlying_attn_backend.get_builder_cls()
@@ -93,6 +93,6 @@ def __init__(
**kwargs,
)

def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec:
def get_kv_cache_spec(self, vllm_config: VllmConfig) -> KVCacheSpec | None:
# Does not need KV cache
return None
27 changes: 17 additions & 10 deletions vllm/model_executor/layers/attention/mla_attention.py
@@ -389,7 +389,7 @@ def __init__(
cache_config.enable_prefix_caching = False

impl_cls = cast(type[MLAAttentionImpl], self.attn_backend.get_impl_cls())
self.impl = impl_cls(
self.impl = impl_cls( # type: ignore[assignment] # impl_cls always returns an MLAAttentionImpl subclass
num_heads=self.num_heads,
head_size=self.head_size,
scale=self.scale,
@@ -485,16 +485,23 @@ def forward(

if self.use_direct_call:
forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata
if isinstance(attn_metadata, dict):
attn_metadata = attn_metadata[self.layer_name]
attn_metadata_raw = forward_context.attn_metadata
attn_metadata: MLACommonMetadata
if isinstance(attn_metadata_raw, dict):
attn_metadata = attn_metadata_raw[self.layer_name] # type: ignore[assignment]
elif isinstance(attn_metadata_raw, list):
# list[dict[str, AttentionMetadata]]: used in speculative decoding
# where [0] is the base-model (non-speculative) metadata dict.
attn_metadata = attn_metadata_raw[0][self.layer_name] # type: ignore[assignment]
else:
attn_metadata = attn_metadata_raw
self_kv_cache = self.kv_cache
slot_mapping = forward_context.slot_mapping

assert isinstance(slot_mapping, dict), (
f"Expected slot_mapping to be a dict, got {type(slot_mapping)}. "
)
self.impl.do_kv_cache_update(
self.impl.do_kv_cache_update( # type: ignore[attr-defined]
kv_c_normed,
k_pe,
self_kv_cache,
@@ -612,7 +619,7 @@ def forward_impl(
num_mha_tokens = q.size(0) - num_mqa_tokens

if num_mha_tokens > 0:
self.impl.forward_mha(
self.impl.forward_mha( # type: ignore[attr-defined]
q[num_mqa_tokens:],
k_c_normed[num_mqa_tokens:],
k_pe[num_mqa_tokens:],
@@ -695,7 +702,7 @@ def forward_impl(
# call decode attn
if not is_sparse_impl:
assert attn_metadata.decode is not None
attn_out, lse = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self)
attn_out, lse = self.impl.forward_mqa(mqa_q, kv_cache, attn_metadata, self) # type: ignore[attr-defined]

# correct dcp attn_out with lse.
if self.impl.dcp_world_size > 1:
@@ -1053,9 +1060,9 @@ class QueryLenSupport(Enum):
"AITER_MLA backends use aiter kernels instead."
)
elif current_platform.is_xpu():
from vllm._xpu_ops import xpu_ops as ops
from vllm._xpu_ops import xpu_ops

flash_attn_varlen_func = ops.flash_attn_varlen_func # type: ignore[no-redef]
flash_attn_varlen_func = xpu_ops.flash_attn_varlen_func # type: ignore[no-redef,attr-defined,assignment]


def dynamic_per_batched_tensor_quant(
@@ -1988,7 +1995,7 @@ def build(
assert isinstance(attn_metadata.prefill, FlashInferPrefillMetadata)
self._build_fi_prefill_wrappers(attn_metadata.prefill)

return attn_metadata
return attn_metadata # type: ignore[return-value]


def reorg_kvcache(
11 changes: 7 additions & 4 deletions vllm/model_executor/layers/fused_moe/all2all_utils.py
@@ -117,17 +117,20 @@ def maybe_make_prepare_finalize(
"Detected DP deployment with no --enable-expert-parallel. "
"Falling back to AllGather+ReduceScatter dispatch/combine."
)
device_communicator = get_ep_group().device_communicator
assert device_communicator is not None
assert device_communicator.all2all_manager is not None
return make_moe_prepare_and_finalize_naive_dp_ep(
is_sequence_parallel=moe.moe_parallel_config.is_sequence_parallel,
num_dispatchers=(
get_ep_group().device_communicator.all2all_manager.world_size
),
num_dispatchers=(device_communicator.all2all_manager.world_size),
use_monolithic=use_monolithic,
)
else:
return make_moe_prepare_and_finalize_no_dp_ep(use_monolithic)

all2all_manager = get_ep_group().device_communicator.all2all_manager
device_communicator = get_ep_group().device_communicator
assert device_communicator is not None
all2all_manager = device_communicator.all2all_manager
assert all2all_manager is not None

prepare_finalize: FusedMoEPrepareAndFinalize | None = None
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/fused_moe/config.py
@@ -7,6 +7,7 @@
import torch

from vllm.config import ParallelConfig, SchedulerConfig
from vllm.config.kernel import MoEBackend
from vllm.distributed import get_dp_group, get_pcp_group, get_tensor_model_parallel_rank
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
@@ -1193,7 +1194,7 @@ class FusedMoEConfig:
# Defaults to intermediate_size_per_partition if not specified.
intermediate_size_per_partition_unpadded: int | None = None

moe_backend: str = "auto"
moe_backend: MoEBackend = "auto"
max_num_tokens: int = SchedulerConfig.DEFAULT_MAX_NUM_BATCHED_TOKENS_FOR_BATCHED_DP
has_bias: bool = False
is_act_and_mul: bool = True
@@ -210,9 +210,9 @@ def persistent_masked_m_silu_mul_quant(
DeepGemmQuantScaleFMT.UE8M0,
]

cuda_arch = current_platform.get_device_capability(
device_id=y.device.index
).to_int()
device_capability = current_platform.get_device_capability(device_id=y.device.index)
assert device_capability is not None
cuda_arch = device_capability.to_int()

if current_platform.is_cuda() and cuda_arch >= 80:
torch.ops._C.persistent_masked_m_silu_mul_quant(
13 changes: 8 additions & 5 deletions vllm/model_executor/layers/fused_moe/oracle/mxfp4.py
@@ -7,6 +7,7 @@

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs
from vllm.config.kernel import MoEBackend
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
FusedMoEConfig,
@@ -146,7 +147,7 @@ def backend_to_kernel_cls(
raise ValueError(f"Unknown MXFP4 MoE backend: {backend.value}")


def map_mxfp4_backend(runner_backend: str) -> Mxfp4MoeBackend:
def map_mxfp4_backend(runner_backend: MoEBackend) -> Mxfp4MoeBackend:
"""Map user's moe_backend string to Mxfp4MoeBackend."""
mapping: dict[str, Mxfp4MoeBackend] = {
"flashinfer_trtllm": Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_BF16,
@@ -201,10 +202,12 @@ def select_gpt_oss_mxfp4_moe_backend(
Select the primary MXFP4 MoE backend.
Note: Shape-specific fallbacks may still occur at runtime.
"""
triton_kernels_supported = has_triton_kernels() and (
9,
0,
) <= current_platform.get_device_capability() < (11, 0)
device_capability = current_platform.get_device_capability()
triton_kernels_supported = (
has_triton_kernels()
and device_capability is not None
and (9, 0) <= device_capability < (11, 0)
)

# LoRA: separate experts backend path
if config.is_lora_enabled:
Expand Down