2 changes: 1 addition & 1 deletion docs/features/batch_invariance.md
@@ -105,7 +105,7 @@ Batch invariance has been tested and verified on the following models:

- **DeepSeek series**: `deepseek-ai/DeepSeek-V3`, `deepseek-ai/DeepSeek-V3-0324`, `deepseek-ai/DeepSeek-R1`, `deepseek-ai/DeepSeek-V3.1`
- **Qwen3 (Dense)**: `Qwen/Qwen3-1.7B`, `Qwen/Qwen3-8B`, `Qwen/Qwen3-4B-AWQ`, `Qwen/Qwen3-8B-AWQ`
-- **Qwen3 (MoE)**: `Qwen/Qwen3-30B-A3B`, `Qwen/Qwen3-Next-80B-A3B-Instruct`
+- **Qwen3 (MoE)**: `Qwen/Qwen3-30B-A3B`, `Qwen/Qwen3-Next-80B-A3B-Instruct`, `Qwen/Qwen3-30B-A3B-Thinking-2507-FP8`
- **Qwen2.5**: `Qwen/Qwen2.5-0.5B-Instruct`, `Qwen/Qwen2.5-1.5B-Instruct`, `Qwen/Qwen2.5-3B-Instruct`, `Qwen/Qwen2.5-7B-Instruct`, `Qwen/Qwen2.5-14B-Instruct`, `Qwen/Qwen2.5-32B-Instruct`
- **Llama 3**: `meta-llama/Llama-3.1-8B-Instruct`, `meta-llama/Llama-3.2-1B-Instruct`
- **GPT-OSS**: `openai/gpt-oss-20b`, `openai/gpt-oss-120b`
28 changes: 28 additions & 0 deletions tests/model_executor/model_loader/test_reload.py
@@ -59,6 +59,34 @@ def test_reload_lifecycle():
assert tensor.__dict__ == materialized_tensor.__dict__


def test_materialize_layer_preserves_non_meta_tensors():
"""Ensure that materialize_layer does not overwrite non meta tensors."""
layer = torch.nn.Linear(2, 3, bias=True)

# Create a non-meta bias tensor and a meta weight, which can happen with FP8
bias_values = torch.ones(3)
layer.bias.data.copy_(bias_values)
layer.weight = torch.nn.Parameter(layer.weight.data.to("meta"))

assert layer.weight.is_meta
assert not layer.bias.is_meta

# materialize the layer weights after the bias is initialized
info = LayerReloadingInfo(
restore_metadata=({}, {}),
restore_device=torch.device("cpu"),
)
materialize_layer(layer, info)

# Ensure the weight materialized off meta
assert not layer.weight.is_meta
assert layer.weight.device.type == "cpu"

# Ensure that the bias is (still) not meta and values are unchanged
assert not layer.bias.is_meta
assert torch.equal(layer.bias.data, bias_values)


def test_model_cleanup(dist_init, default_vllm_config):
layer = QKVParallelLinear(2, 3, 4)
assert layer.weight.weight_loader.__self__ is layer
Expand Down
162 changes: 162 additions & 0 deletions tests/v1/attention/test_kv_head_stride_canonicalization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Unit tests for canonicalize_singleton_dim_strides.

Background
----------
When num_kv_heads_per_rank == 1 (e.g. Qwen3.5-397B with TP=8 → 1 KV head
per rank), PyTorch's is_contiguous() returns True for *any* stride on the
size-1 dimension. The KV cache allocator can therefore produce a tensor
where that singleton dim has stride = 1 element (2 bytes for bf16) instead
of the canonical product-of-remaining-dims value.

CUDA TMA (used by FlashInfer XQA SM90 and Flash-Attention 3/4 on H100+)
requires all non-outermost strides to be multiples of 16 bytes. A 2-byte
stride triggers cudaErrorIllegalInstruction.

canonicalize_singleton_dim_strides() patches degenerate strides on all
size-1 dimensions via torch.as_strided — zero-copy.

The degenerate stride manifests at different positions in different backends:
- FlashInfer: stride(-3) after kv_cache.permute() → shape [..., 1, B, D]
- FlashAttention: stride(-2) after kv_cache.unbind(0) → shape [N, B, 1, D]
"""

import torch

from vllm.utils.torch_utils import canonicalize_singleton_dim_strides

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _inject_degenerate_stride(t: torch.Tensor, dim: int) -> torch.Tensor:
"""Return a view of t with a degenerate (stride=1) on a size-1 dim."""
assert t.shape[dim] == 1, f"dim {dim} must have size 1"
strides = list(t.stride())
strides[dim] = 1 # inject the bug
return t.as_strided(t.shape, strides)


# ---------------------------------------------------------------------------
# Tests: canonicalize_singleton_dim_strides
# ---------------------------------------------------------------------------


class TestCanonicalizeSingletonDimStrides:
def test_flashinfer_layout_dim_neg3(self):
"""FlashInfer path: degenerate stride at dim -3 (num_kv_heads)."""
# Shape after permute: [num_blocks, 2, num_kv_heads, block_size, head_size]
num_blocks, block_size, head_size = 64, 16, 128
t = torch.zeros(num_blocks, 2, 1, block_size, head_size, dtype=torch.bfloat16)
t_deg = _inject_degenerate_stride(t, dim=-3)

assert t_deg.stride(-3) == 1 # confirm degenerate
assert t_deg.is_contiguous() # PyTorch doesn't notice

fixed = canonicalize_singleton_dim_strides(t_deg)

assert fixed.stride(-3) == block_size * head_size # canonical = 2048
assert fixed.stride(-2) == head_size # inner dims unchanged
assert fixed.stride(-1) == 1

def test_flash_attn_layout_dim_neg2(self):
"""FlashAttention path: degenerate stride at dim -2 (num_kv_heads)."""
# Shape after unbind(0): [num_blocks, block_size, num_kv_heads, head_size]
num_blocks, block_size, head_size = 64, 16, 128
t = torch.zeros(num_blocks, block_size, 1, head_size, dtype=torch.bfloat16)
t_deg = _inject_degenerate_stride(t, dim=-2)

assert t_deg.stride(-2) == 1
assert t_deg.is_contiguous()

fixed = canonicalize_singleton_dim_strides(t_deg)

assert fixed.stride(-2) == head_size # canonical = 128
assert fixed.stride(-1) == 1

def test_canonical_strides_returned_as_is(self):
"""No degenerate strides → same object returned (no copy, no new view)."""
t = torch.zeros(64, 2, 1, 16, 128, dtype=torch.bfloat16)
result = canonicalize_singleton_dim_strides(t)
assert result is t

def test_multi_kv_heads_unchanged(self):
"""num_kv_heads > 1 → strides are already canonical → unchanged."""
t = torch.zeros(16, 2, 4, 16, 128, dtype=torch.bfloat16)
original_strides = t.stride()
result = canonicalize_singleton_dim_strides(t)
assert result.stride() == original_strides

def test_data_pointer_preserved(self):
"""Fix is zero-copy: same underlying storage."""
t = torch.zeros(8, 2, 1, 16, 128, dtype=torch.bfloat16)
t_deg = _inject_degenerate_stride(t, dim=-3)
fixed = canonicalize_singleton_dim_strides(t_deg)
assert fixed.data_ptr() == t_deg.data_ptr()
assert fixed.storage_offset() == t_deg.storage_offset()

def test_multiple_singleton_dims(self):
"""All size-1 dims with degenerate strides are fixed."""
# Shape: [1, 1, 8, 32] — two size-1 dims
t = torch.zeros(1, 1, 8, 32, dtype=torch.float16)
# Both size-1 dims get degenerate strides
t_deg = t.as_strided(t.shape, (1, 1, 32, 1)) # both leading dims = 1

fixed = canonicalize_singleton_dim_strides(t_deg)

assert fixed.stride(0) == 1 * 8 * 32 # canonical: 256
assert fixed.stride(1) == 1 * 8 * 32 # canonical: 256 (same since size-1)
assert fixed.stride(2) == 32
assert fixed.stride(3) == 1

def test_various_shapes_flashinfer(self):
"""Correctness across different block_size / head_size for FlashInfer layout."""
for block_size, head_size in [(16, 64), (16, 128), (32, 128), (16, 256)]:
t = torch.zeros(8, 2, 1, block_size, head_size, dtype=torch.bfloat16)
t_deg = _inject_degenerate_stride(t, dim=-3)
fixed = canonicalize_singleton_dim_strides(t_deg)
assert fixed.stride(-3) == block_size * head_size, (
f"Failed for block_size={block_size}, head_size={head_size}: "
f"got stride(-3)={fixed.stride(-3)}"
)

def test_various_shapes_flash_attn(self):
"""Correctness across different shapes for FlashAttention layout."""
for block_size, head_size in [(16, 64), (16, 128), (32, 128)]:
t = torch.zeros(8, block_size, 1, head_size, dtype=torch.bfloat16)
t_deg = _inject_degenerate_stride(t, dim=-2)
fixed = canonicalize_singleton_dim_strides(t_deg)
assert fixed.stride(-2) == head_size, (
f"Failed for block_size={block_size}, head_size={head_size}: "
f"got stride(-2)={fixed.stride(-2)}"
)

def test_tma_alignment_satisfied_after_fix_bf16(self):
"""After fix, all strides meet 16-byte TMA alignment for bf16."""
t = torch.zeros(64, 2, 1, 16, 128, dtype=torch.bfloat16)
t_deg = _inject_degenerate_stride(t, dim=-3)
fixed = canonicalize_singleton_dim_strides(t_deg)

element_size = fixed.element_size() # 2 bytes for bf16
for i, s in enumerate(fixed.stride()):
assert (s * element_size) % 16 == 0 or i == len(fixed.stride()) - 1, (
f"dim {i} stride {s} * {element_size} bytes not 16-byte aligned"
)

def test_non_contiguous_outer_dims_preserved(self):
"""Outer (non-size-1) non-contiguous strides are left unchanged."""
# Simulate cross-layer unified allocation: num_blocks stride is non-canonical
# but the inner dims should be fixed.
base = torch.zeros(200, 2, 1, 16, 128, dtype=torch.bfloat16)
# Slice every 2nd block → non-canonical outer stride
t_sliced = base[::2] # shape [100, 2, 1, 16, 128], stride[0] = 2*canonical
t_deg = _inject_degenerate_stride(t_sliced, dim=-3)

fixed = canonicalize_singleton_dim_strides(t_deg)

# Outer stride should be unchanged (not a size-1 dim)
assert fixed.stride(0) == t_sliced.stride(0)
# Inner degenerate stride should be fixed
assert fixed.stride(-3) == 16 * 128
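
A minimal sketch of the canonicalization these tests exercise, written from the behaviour described in the module docstring and the assertions above (the real vllm.utils.torch_utils.canonicalize_singleton_dim_strides may differ in details): each size-1 dim whose stride deviates from the product of the trailing dim sizes is rewritten through torch.as_strided, so the fix is zero-copy, and the input tensor is returned unchanged when nothing needs patching.

import torch


def canonicalize_singleton_dim_strides_sketch(t: torch.Tensor) -> torch.Tensor:
    # Hypothetical reference implementation, for illustration only.
    strides = list(t.stride())
    changed = False
    for dim in range(t.dim()):
        if t.shape[dim] != 1:
            continue  # only size-1 dims can hide a degenerate stride
        # Canonical stride of a size-1 dim: product of the trailing dim sizes.
        canonical = 1
        for size in t.shape[dim + 1 :]:
            canonical *= size
        if strides[dim] != canonical:
            strides[dim] = canonical
            changed = True
    if not changed:
        return t  # same object, no new view
    # as_strided keeps the same storage and storage_offset, so no data is copied.
    return t.as_strided(t.shape, strides)
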
8 changes: 6 additions & 2 deletions vllm/config/vllm.py
@@ -209,7 +209,9 @@ def enable_mla_dual_rms_norm_fusion(cfg: "VllmConfig") -> bool:
"use_inductor_graph_partition": False,
},
"kernel_config": {
"enable_flashinfer_autotune": True,
# Disabled for now due to correctness issues:
# https://github.com/flashinfer-ai/flashinfer/issues/3197
"enable_flashinfer_autotune": False,
},
}
OPTIMIZATION_LEVEL_02 = {
@@ -229,7 +231,9 @@ def enable_mla_dual_rms_norm_fusion(cfg: "VllmConfig") -> bool:
"use_inductor_graph_partition": False,
},
"kernel_config": {
"enable_flashinfer_autotune": True,
# Disabled for now due to correctness issues:
# https://github.com/flashinfer-ai/flashinfer/issues/3197
"enable_flashinfer_autotune": False,
},
}
OPTIMIZATION_LEVEL_03 = {
15 changes: 15 additions & 0 deletions vllm/model_executor/layers/fused_moe/activation.py
@@ -15,6 +15,7 @@ class MoEActivation(Enum):
# and produce output of shape [..., d]
SILU = "silu"
GELU = "gelu"
GELU_TANH = "gelu_tanh"
RELU2 = "relu2"
SWIGLUOAI = "swigluoai"
SWIGLUSTEP = "swiglustep"
@@ -24,6 +25,7 @@ class MoEActivation(Enum):
# NOTE: Non-gated activations require the "_no_mul" suffix to be present.
SILU_NO_MUL = "silu_no_mul"
GELU_NO_MUL = "gelu_no_mul"
GELU_TANH_NO_MUL = "gelu_tanh_no_mul"
RELU2_NO_MUL = "relu2_no_mul"

@property
@@ -53,6 +55,8 @@ def without_mul(self) -> "MoEActivation":
@classmethod
def from_str(cls, s: str) -> "MoEActivation":
"""Parse from string for backward compatibility."""
if s == "gelu_pytorch_tanh":
s = cls.GELU_TANH.value
for member in cls:
if member.value == s:
return member
@@ -64,17 +68,20 @@ def from_str(cls, s: str) -> "MoEActivation":
_CUSTOM_OP_NAMES: dict[MoEActivation, str] = {
MoEActivation.SILU: "silu_and_mul",
MoEActivation.GELU: "gelu_and_mul",
MoEActivation.GELU_TANH: "gelu_tanh_and_mul",
MoEActivation.SWIGLUOAI: "swigluoai_and_mul",
MoEActivation.SWIGLUSTEP: "swiglustep_and_mul",
MoEActivation.RELU2: "relu2",
MoEActivation.SILU_NO_MUL: "silu_and_mul",
MoEActivation.GELU_NO_MUL: "gelu_and_mul",
MoEActivation.GELU_TANH_NO_MUL: "gelu_tanh_and_mul",
MoEActivation.RELU2_NO_MUL: "relu2",
}

_WITHOUT_MUL: dict[MoEActivation, MoEActivation] = {
MoEActivation.SILU: MoEActivation.SILU_NO_MUL,
MoEActivation.GELU: MoEActivation.GELU_NO_MUL,
MoEActivation.GELU_TANH: MoEActivation.GELU_TANH_NO_MUL,
MoEActivation.RELU2: MoEActivation.RELU2_NO_MUL,
}

@@ -115,6 +122,12 @@ def apply_moe_activation(
torch.ops._C.silu_and_mul(output, input)
elif activation == MoEActivation.GELU:
torch.ops._C.gelu_and_mul(output, input)
elif activation == MoEActivation.GELU_TANH:
if hasattr(torch.ops._C, "gelu_tanh_and_mul"):
torch.ops._C.gelu_tanh_and_mul(output, input)
else:
gate, up = input.chunk(2, dim=-1)
output.copy_(F.gelu(gate, approximate="tanh") * up)
elif activation == MoEActivation.SWIGLUOAI:
torch.ops._C.swigluoai_and_mul(output, input)
elif activation == MoEActivation.SWIGLUSTEP:
@@ -127,6 +140,8 @@ def apply_moe_activation(
output.copy_(F.silu(input))
elif activation == MoEActivation.GELU_NO_MUL:
output.copy_(F.gelu(input))
elif activation == MoEActivation.GELU_TANH_NO_MUL:
output.copy_(F.gelu(input, approximate="tanh"))
elif activation == MoEActivation.RELU2_NO_MUL:
F.relu(input, inplace=True)
torch.square(input, out=output)
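
For reference, the gated GELU_TANH path added above interprets an input of shape [..., 2 * d] as a gate half followed by an "up" half; a minimal pure-PyTorch sketch of that semantics, mirroring the fallback branch above rather than the custom gelu_tanh_and_mul kernel:

import torch
import torch.nn.functional as F


def gelu_tanh_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
    # x: [..., 2 * d]; tanh-approximated GELU of the gate, multiplied by "up".
    d = x.shape[-1] // 2
    gate, up = x[..., :d], x[..., d:]
    return F.gelu(gate, approximate="tanh") * up
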
4 changes: 4 additions & 0 deletions vllm/model_executor/layers/fused_moe/cpu_fused_moe.py
@@ -48,6 +48,10 @@ def _gelu_and_mul(
MoEActivation.SILU: lambda x: SiluAndMul(compile_native=False).forward_native(x),
MoEActivation.SWIGLUOAI: _swigluoai_forward_native,
MoEActivation.GELU: _gelu_and_mul,
MoEActivation.GELU_TANH: (
lambda x: F.gelu(x[..., : x.shape[-1] // 2], approximate="tanh")
* x[..., x.shape[-1] // 2 :]
),
}


3 changes: 3 additions & 0 deletions vllm/model_executor/layers/fused_moe/experts/cutlass_moe.py
@@ -319,6 +319,7 @@ def _supports_activation(activation: MoEActivation) -> bool:
return activation in [
MoEActivation.SILU,
MoEActivation.GELU,
MoEActivation.GELU_TANH,
MoEActivation.SWIGLUOAI,
]

@@ -709,10 +710,12 @@ def _supports_activation(activation: MoEActivation) -> bool:
return activation in [
MoEActivation.SILU,
MoEActivation.GELU,
MoEActivation.GELU_TANH,
MoEActivation.SWIGLUOAI,
MoEActivation.SWIGLUSTEP,
MoEActivation.SILU_NO_MUL,
MoEActivation.GELU_NO_MUL,
MoEActivation.GELU_TANH_NO_MUL,
MoEActivation.RELU2_NO_MUL,
]

2 changes: 2 additions & 0 deletions vllm/model_executor/layers/fused_moe/fused_batched_moe.py
@@ -786,9 +786,11 @@ def _supports_activation(activation: MoEActivation) -> bool:
return activation in [
MoEActivation.SILU,
MoEActivation.GELU,
MoEActivation.GELU_TANH,
MoEActivation.SWIGLUOAI,
MoEActivation.SILU_NO_MUL,
MoEActivation.GELU_NO_MUL,
MoEActivation.GELU_TANH_NO_MUL,
MoEActivation.RELU2_NO_MUL,
]

2 changes: 2 additions & 0 deletions vllm/model_executor/layers/fused_moe/fused_marlin_moe.py
@@ -613,10 +613,12 @@ def _supports_activation(activation: MoEActivation) -> bool:
return activation in [
MoEActivation.SILU,
MoEActivation.GELU,
MoEActivation.GELU_TANH,
MoEActivation.SWIGLUOAI,
MoEActivation.SWIGLUSTEP,
MoEActivation.SILU_NO_MUL,
MoEActivation.GELU_NO_MUL,
MoEActivation.GELU_TANH_NO_MUL,
MoEActivation.RELU2_NO_MUL,
]

2 changes: 2 additions & 0 deletions vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -1941,10 +1941,12 @@ def _supports_activation(activation: MoEActivation) -> bool:
return activation in [
MoEActivation.SILU,
MoEActivation.GELU,
MoEActivation.GELU_TANH,
MoEActivation.SWIGLUOAI,
MoEActivation.SWIGLUSTEP,
MoEActivation.SILU_NO_MUL,
MoEActivation.GELU_NO_MUL,
MoEActivation.GELU_TANH_NO_MUL,
MoEActivation.RELU2_NO_MUL,
]

6 changes: 1 addition & 5 deletions vllm/model_executor/layers/quantization/moe_wna16.py
@@ -6,7 +6,6 @@
import torch

from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
-from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
int4_w4a16_moe_quant_config,
@@ -372,17 +371,14 @@ def apply(
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts

-assert layer.activation == MoEActivation.SILU, (
-f"Only SiLU activation is supported, not {layer.activation}."
-)
-
return fused_experts(
x,
layer.w13_qweight,
layer.w2_qweight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=not self.moe.disable_inplace,
+activation=layer.activation,
apply_router_weight_on_input=layer.apply_router_weight_on_input,
global_num_experts=layer.global_num_experts,
expert_map=layer.expert_map,