From 54dc64d5d399d960b2f0bf7a00f1a90d16ca6178 Mon Sep 17 00:00:00 2001 From: Taneem Ibrahim Date: Sun, 3 May 2026 07:47:55 -0500 Subject: [PATCH 1/6] [Doc] Add Qwen3-30B-A3B-Thinking-2507-FP8 to batch invariance verified models (#41513) Signed-off-by: Taneem Ibrahim --- docs/features/batch_invariance.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/features/batch_invariance.md b/docs/features/batch_invariance.md index 804cd905e3b..b2363148450 100644 --- a/docs/features/batch_invariance.md +++ b/docs/features/batch_invariance.md @@ -105,7 +105,7 @@ Batch invariance has been tested and verified on the following models: - **DeepSeek series**: `deepseek-ai/DeepSeek-V3`, `deepseek-ai/DeepSeek-V3-0324`, `deepseek-ai/DeepSeek-R1`, `deepseek-ai/DeepSeek-V3.1` - **Qwen3 (Dense)**: `Qwen/Qwen3-1.7B`, `Qwen/Qwen3-8B`, `Qwen/Qwen3-4B-AWQ`, `Qwen/Qwen3-8B-AWQ` -- **Qwen3 (MoE)**: `Qwen/Qwen3-30B-A3B`, `Qwen/Qwen3-Next-80B-A3B-Instruct` +- **Qwen3 (MoE)**: `Qwen/Qwen3-30B-A3B`, `Qwen/Qwen3-Next-80B-A3B-Instruct`, `Qwen/Qwen3-30B-A3B-Thinking-2507-FP8` - **Qwen2.5**: `Qwen/Qwen2.5-0.5B-Instruct`, `Qwen/Qwen2.5-1.5B-Instruct`, `Qwen/Qwen2.5-3B-Instruct`, `Qwen/Qwen2.5-7B-Instruct`, `Qwen/Qwen2.5-14B-Instruct`, `Qwen/Qwen2.5-32B-Instruct` - **Llama 3**: `meta-llama/Llama-3.1-8B-Instruct`, `meta-llama/Llama-3.2-1B-Instruct` - **GPT-OSS**: `openai/gpt-oss-20b`, `openai/gpt-oss-120b` From c51df43005726a09c6eb7348e8c1b00501c70a8e Mon Sep 17 00:00:00 2001 From: Wei Zhao <51183510+wzhao18@users.noreply.github.com> Date: Sun, 3 May 2026 12:19:59 -0400 Subject: [PATCH 2/6] Disable flashinfer autotune temporarily due to correctness issues (#41524) Signed-off-by: wzhao18 --- vllm/config/vllm.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 0146ee4c144..88e6660e216 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -209,7 +209,9 @@ def enable_mla_dual_rms_norm_fusion(cfg: "VllmConfig") -> bool: "use_inductor_graph_partition": False, }, "kernel_config": { - "enable_flashinfer_autotune": True, + # Disabled for now due to correctness issues: + # https://github.com/flashinfer-ai/flashinfer/issues/3197 + "enable_flashinfer_autotune": False, }, } OPTIMIZATION_LEVEL_02 = { @@ -229,7 +231,9 @@ def enable_mla_dual_rms_norm_fusion(cfg: "VllmConfig") -> bool: "use_inductor_graph_partition": False, }, "kernel_config": { - "enable_flashinfer_autotune": True, + # Disabled for now due to correctness issues: + # https://github.com/flashinfer-ai/flashinfer/issues/3197 + "enable_flashinfer_autotune": False, }, } OPTIMIZATION_LEVEL_03 = { From cb03fee32b5c191b3ac5a248e47cd58a66c75591 Mon Sep 17 00:00:00 2001 From: tomeras91 <57313761+tomeras91@users.noreply.github.com> Date: Sun, 3 May 2026 23:00:41 +0300 Subject: [PATCH 3/6] [Bugfix][Ray] Fix RayExecutorV2 actor name collision with DP > 1 (#40398) Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> --- vllm/v1/engine/utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index 7b0f00d14c8..1f0b9bbb19d 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -403,6 +403,11 @@ def __init__( range(dp_size), local_dp_ranks, placement_groups ): dp_vllm_config = copy.deepcopy(vllm_config) + if dp_size > 1: + # Append the DP rank to instance_id so that per-engine + # identifiers (e.g. Ray actor names in RayExecutorV2) are + # unique across DP replicas. 
+ dp_vllm_config.instance_id = f"{dp_vllm_config.instance_id}_dp{index}" dp_vllm_config.parallel_config.placement_group = pg local_client = index < local_engine_count From db9a84e0cd0e17ab693467ff4a71103abd4b77bf Mon Sep 17 00:00:00 2001 From: Alex Brooks Date: Sun, 3 May 2026 14:30:04 -0600 Subject: [PATCH 4/6] [Bugfix] Fix FP8 Bias Loading (#41424) Signed-off-by: Alex Brooks --- .../model_loader/test_reload.py | 28 +++++++++++++++++++ .../model_loader/reload/meta.py | 2 +- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/tests/model_executor/model_loader/test_reload.py b/tests/model_executor/model_loader/test_reload.py index 6e3e2d63e14..cf3553bd57d 100644 --- a/tests/model_executor/model_loader/test_reload.py +++ b/tests/model_executor/model_loader/test_reload.py @@ -59,6 +59,34 @@ def test_reload_lifecycle(): assert tensor.__dict__ == materialized_tensor.__dict__ +def test_materialize_layer_preserves_non_meta_tensors(): + """Ensure that materialize_layer does not overwrite non meta tensors.""" + layer = torch.nn.Linear(2, 3, bias=True) + + # Create a non meta bias tensor and meta weight, which can happen with FP8 + bias_values = torch.ones(3) + layer.bias.data.copy_(bias_values) + layer.weight = torch.nn.Parameter(layer.weight.data.to("meta")) + + assert layer.weight.is_meta + assert not layer.bias.is_meta + + # materialize the layer weights after the bias is initialized + info = LayerReloadingInfo( + restore_metadata=({}, {}), + restore_device=torch.device("cpu"), + ) + materialize_layer(layer, info) + + # Ensure the weight materialized off meta + assert not layer.weight.is_meta + assert layer.weight.device.type == "cpu" + + # Ensure that the bias is (still) not meta and values are unchanged + assert not layer.bias.is_meta + assert torch.equal(layer.bias.data, bias_values) + + def test_model_cleanup(dist_init, default_vllm_config): layer = QKVParallelLinear(2, 3, 4) assert layer.weight.weight_loader.__self__ is layer diff --git a/vllm/model_executor/model_loader/reload/meta.py b/vllm/model_executor/model_loader/reload/meta.py index 91fce6f57b3..baa2081d58b 100644 --- a/vllm/model_executor/model_loader/reload/meta.py +++ b/vllm/model_executor/model_loader/reload/meta.py @@ -102,7 +102,7 @@ def materialize_layer(layer: torch.nn.Module, info: LayerReloadingInfo): with info.restore_device: for name, tensor in get_layer_tensors(layer).items(): - if name not in SKIP_TENSORS: + if name not in SKIP_TENSORS and tensor.is_meta: setattr(layer, name, materialize_meta_tensor(tensor)) From 66dfee7121dfcbdfcce04ce92c117e5e10b25b14 Mon Sep 17 00:00:00 2001 From: David Oy <58150256+the-david-oy@users.noreply.github.com> Date: Sun, 3 May 2026 16:52:18 -0700 Subject: [PATCH 5/6] [Bugfix] Fix degenerate KV cache stride causing TMA cudaErrorIllegalInstruction (#40737) Signed-off-by: David Oy Signed-off-by: David Oy <58150256+the-david-oy@users.noreply.github.com> Signed-off-by: David Oy Co-authored-by: David Oy Co-authored-by: Claude Co-authored-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com> --- .../test_kv_head_stride_canonicalization.py | 162 ++++++++++++++++++ vllm/utils/cpu_resource_utils.py | 2 +- vllm/utils/torch_utils.py | 26 +++ vllm/v1/attention/backends/flash_attn.py | 24 ++- .../attention/backends/flash_attn_diffkv.py | 25 ++- vllm/v1/attention/backends/flashinfer.py | 50 ++++-- 6 files changed, 267 insertions(+), 22 deletions(-) create mode 100644 tests/v1/attention/test_kv_head_stride_canonicalization.py diff --git 
a/tests/v1/attention/test_kv_head_stride_canonicalization.py b/tests/v1/attention/test_kv_head_stride_canonicalization.py new file mode 100644 index 00000000000..635f46390cf --- /dev/null +++ b/tests/v1/attention/test_kv_head_stride_canonicalization.py @@ -0,0 +1,162 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for canonicalize_singleton_dim_strides. + +Background +---------- +When num_kv_heads_per_rank == 1 (e.g. Qwen3.5-397B with TP=8 → 1 KV head +per rank), PyTorch's is_contiguous() returns True for *any* stride on the +size-1 dimension. The KV cache allocator can therefore produce a tensor +where that singleton dim has stride = 1 element (2 bytes for bf16) instead +of the canonical product-of-remaining-dims value. + +CUDA TMA (used by FlashInfer XQA SM90 and Flash-Attention 3/4 on H100+) +requires all non-outermost strides to be multiples of 16 bytes. A 2-byte +stride triggers cudaErrorIllegalInstruction. + +canonicalize_singleton_dim_strides() patches degenerate strides on all +size-1 dimensions via torch.as_strided — zero-copy. + +The degenerate stride manifests at different positions in different backends: +- FlashInfer: stride(-3) after kv_cache.permute() → shape [..., 1, B, D] +- FlashAttention: stride(-2) after kv_cache.unbind(0) → shape [N, B, 1, D] +""" + +import torch + +from vllm.utils.torch_utils import canonicalize_singleton_dim_strides + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _inject_degenerate_stride(t: torch.Tensor, dim: int) -> torch.Tensor: + """Return a view of t with a degenerate (stride=1) on a size-1 dim.""" + assert t.shape[dim] == 1, f"dim {dim} must have size 1" + strides = list(t.stride()) + strides[dim] = 1 # inject the bug + return t.as_strided(t.shape, strides) + + +# --------------------------------------------------------------------------- +# Tests: canonicalize_singleton_dim_strides +# --------------------------------------------------------------------------- + + +class TestCanonicalizeSingletonDimStrides: + def test_flashinfer_layout_dim_neg3(self): + """FlashInfer path: degenerate stride at dim -3 (num_kv_heads).""" + # Shape after permute: [num_blocks, 2, num_kv_heads, block_size, head_size] + num_blocks, block_size, head_size = 64, 16, 128 + t = torch.zeros(num_blocks, 2, 1, block_size, head_size, dtype=torch.bfloat16) + t_deg = _inject_degenerate_stride(t, dim=-3) + + assert t_deg.stride(-3) == 1 # confirm degenerate + assert t_deg.is_contiguous() # PyTorch doesn't notice + + fixed = canonicalize_singleton_dim_strides(t_deg) + + assert fixed.stride(-3) == block_size * head_size # canonical = 2048 + assert fixed.stride(-2) == head_size # inner dims unchanged + assert fixed.stride(-1) == 1 + + def test_flash_attn_layout_dim_neg2(self): + """FlashAttention path: degenerate stride at dim -2 (num_kv_heads).""" + # Shape after unbind(0): [num_blocks, block_size, num_kv_heads, head_size] + num_blocks, block_size, head_size = 64, 16, 128 + t = torch.zeros(num_blocks, block_size, 1, head_size, dtype=torch.bfloat16) + t_deg = _inject_degenerate_stride(t, dim=-2) + + assert t_deg.stride(-2) == 1 + assert t_deg.is_contiguous() + + fixed = canonicalize_singleton_dim_strides(t_deg) + + assert fixed.stride(-2) == head_size # canonical = 128 + assert fixed.stride(-1) == 1 + + def test_canonical_strides_returned_as_is(self): + """No degenerate 
strides → same object returned (no copy, no new view).""" + t = torch.zeros(64, 2, 1, 16, 128, dtype=torch.bfloat16) + result = canonicalize_singleton_dim_strides(t) + assert result is t + + def test_multi_kv_heads_unchanged(self): + """num_kv_heads > 1 → strides are already canonical → unchanged.""" + t = torch.zeros(16, 2, 4, 16, 128, dtype=torch.bfloat16) + original_strides = t.stride() + result = canonicalize_singleton_dim_strides(t) + assert result.stride() == original_strides + + def test_data_pointer_preserved(self): + """Fix is zero-copy: same underlying storage.""" + t = torch.zeros(8, 2, 1, 16, 128, dtype=torch.bfloat16) + t_deg = _inject_degenerate_stride(t, dim=-3) + fixed = canonicalize_singleton_dim_strides(t_deg) + assert fixed.data_ptr() == t_deg.data_ptr() + assert fixed.storage_offset() == t_deg.storage_offset() + + def test_multiple_singleton_dims(self): + """All size-1 dims with degenerate strides are fixed.""" + # Shape: [1, 1, 8, 32] — two size-1 dims + t = torch.zeros(1, 1, 8, 32, dtype=torch.float16) + # Both size-1 dims get degenerate strides + t_deg = t.as_strided(t.shape, (1, 1, 32, 1)) # both leading dims = 1 + + fixed = canonicalize_singleton_dim_strides(t_deg) + + assert fixed.stride(0) == 1 * 8 * 32 # canonical: 256 + assert fixed.stride(1) == 1 * 8 * 32 # canonical: 256 (same since size-1) + assert fixed.stride(2) == 32 + assert fixed.stride(3) == 1 + + def test_various_shapes_flashinfer(self): + """Correctness across different block_size / head_size for FlashInfer layout.""" + for block_size, head_size in [(16, 64), (16, 128), (32, 128), (16, 256)]: + t = torch.zeros(8, 2, 1, block_size, head_size, dtype=torch.bfloat16) + t_deg = _inject_degenerate_stride(t, dim=-3) + fixed = canonicalize_singleton_dim_strides(t_deg) + assert fixed.stride(-3) == block_size * head_size, ( + f"Failed for block_size={block_size}, head_size={head_size}: " + f"got stride(-3)={fixed.stride(-3)}" + ) + + def test_various_shapes_flash_attn(self): + """Correctness across different shapes for FlashAttention layout.""" + for block_size, head_size in [(16, 64), (16, 128), (32, 128)]: + t = torch.zeros(8, block_size, 1, head_size, dtype=torch.bfloat16) + t_deg = _inject_degenerate_stride(t, dim=-2) + fixed = canonicalize_singleton_dim_strides(t_deg) + assert fixed.stride(-2) == head_size, ( + f"Failed for block_size={block_size}, head_size={head_size}: " + f"got stride(-2)={fixed.stride(-2)}" + ) + + def test_tma_alignment_satisfied_after_fix_bf16(self): + """After fix, all strides meet 16-byte TMA alignment for bf16.""" + t = torch.zeros(64, 2, 1, 16, 128, dtype=torch.bfloat16) + t_deg = _inject_degenerate_stride(t, dim=-3) + fixed = canonicalize_singleton_dim_strides(t_deg) + + element_size = fixed.element_size() # 2 bytes for bf16 + for i, s in enumerate(fixed.stride()): + assert (s * element_size) % 16 == 0 or i == len(fixed.stride()) - 1, ( + f"dim {i} stride {s} * {element_size} bytes not 16-byte aligned" + ) + + def test_non_contiguous_outer_dims_preserved(self): + """Outer (non-size-1) non-contiguous strides are left unchanged.""" + # Simulate cross-layer unified allocation: num_blocks stride is non-canonical + # but the inner dims should be fixed. 
+ base = torch.zeros(200, 2, 1, 16, 128, dtype=torch.bfloat16) + # Slice every 2nd block → non-canonical outer stride + t_sliced = base[::2] # shape [100, 2, 1, 16, 128], stride[0] = 2*canonical + t_deg = _inject_degenerate_stride(t_sliced, dim=-3) + + fixed = canonicalize_singleton_dim_strides(t_deg) + + # Outer stride should be unchanged (not a size-1 dim) + assert fixed.stride(0) == t_sliced.stride(0) + # Inner degenerate stride should be fixed + assert fixed.stride(-3) == 16 * 128 diff --git a/vllm/utils/cpu_resource_utils.py b/vllm/utils/cpu_resource_utils.py index bbf554d0ccd..25c299a0c0c 100644 --- a/vllm/utils/cpu_resource_utils.py +++ b/vllm/utils/cpu_resource_utils.py @@ -125,7 +125,7 @@ def get_allowed_cpu_list() -> list[LogicalCPUInfo]: if platform.system() == "Darwin": return cpu_list - global_allowed_cpu_id_list = os.sched_getaffinity(0) + global_allowed_cpu_id_list = os.sched_getaffinity(0) # type: ignore[attr-defined] logical_cpu_list = [x for x in cpu_list if x.id in global_allowed_cpu_id_list] return logical_cpu_list diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index 1eb9306ed4b..798c136fc23 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -110,6 +110,32 @@ def is_strictly_contiguous(t: torch.Tensor) -> bool: return True +def canonicalize_singleton_dim_strides(t: torch.Tensor) -> torch.Tensor: + """Fix degenerate strides on size=1 dimensions for CUDA TMA compatibility. + + PyTorch allows any stride on a size=1 dim (is_contiguous() is always True + there), so a size=1 dim may have stride=1 (2 bytes for bf16) instead of + the canonical product(shape[i+1:]). CUDA TMA on H100+ requires all + non-outermost strides to be ≥16-byte aligned; stride=1 triggers + cudaErrorIllegalInstruction. Zero-copy: patches stride metadata only via + as_strided; returns t unchanged if all size=1 strides are already canonical. + """ + if 1 not in t.shape: + return t + strides = list(t.stride()) + shape = t.shape + prev_stride = 1 + changed = False + for i in range(len(shape) - 1, -1, -1): + if shape[i] == 1 and strides[i] != prev_stride: + strides[i] = prev_stride + changed = True + prev_stride = strides[i] * shape[i] + if not changed: + return t + return t.as_strided(t.shape, strides) + + @contextlib.contextmanager def set_default_torch_dtype(dtype: torch.dtype): """Sets the default torch dtype to the given dtype.""" diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 1c9ff3f79e4..e73954ee747 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -11,7 +11,10 @@ from vllm.model_executor.layers.attention import Attention from vllm.platforms import current_platform -from vllm.utils.torch_utils import is_quantized_kv_cache +from vllm.utils.torch_utils import ( + canonicalize_singleton_dim_strides, + is_quantized_kv_cache, +) from vllm.v1.attention.backend import ( AttentionBackend, AttentionImpl, @@ -747,6 +750,23 @@ def forward( # For decoder and cross-attention, use KV cache as before key_cache, value_cache = kv_cache.unbind(0) + # Fix degenerate strides on size-1 dims (e.g. num_kv_heads=1 with TP). + # FA3/4 on H100+ uses TMA, which requires ≥16-byte stride alignment. + # See vllm.utils.torch_utils.canonicalize_singleton_dim_strides. 
+ fixed_k = canonicalize_singleton_dim_strides(key_cache) + fixed_v = canonicalize_singleton_dim_strides(value_cache) + if fixed_k is not key_cache or fixed_v is not value_cache: + logger.debug( + "Canonicalized degenerate KV cache strides (FlashAttention): " + "shape=%s, key strides before=%s after=%s, " + "value strides before=%s after=%s", + key_cache.shape, + key_cache.stride(), + fixed_k.stride(), + value_cache.stride(), + fixed_v.stride(), + ) + key_cache, value_cache = fixed_k, fixed_v if is_quantized_kv_cache(self.kv_cache_dtype): # queries are quantized in the attention layer @@ -861,6 +881,8 @@ def do_kv_cache_update( # we use direct Q, K, V tensors without caching return + # Scatter write into the KV cache using slot_mapping indices. + # No TMA kernel is invoked here, so stride canonicalization is not needed. key_cache, value_cache = kv_cache.unbind(0) # Reshape the input keys and values and store them in the cache. diff --git a/vllm/v1/attention/backends/flash_attn_diffkv.py b/vllm/v1/attention/backends/flash_attn_diffkv.py index d1805476971..82a9f07a4e5 100644 --- a/vllm/v1/attention/backends/flash_attn_diffkv.py +++ b/vllm/v1/attention/backends/flash_attn_diffkv.py @@ -4,7 +4,11 @@ import torch -from vllm.utils.torch_utils import is_quantized_kv_cache +from vllm.logger import init_logger +from vllm.utils.torch_utils import ( + canonicalize_singleton_dim_strides, + is_quantized_kv_cache, +) from vllm.v1.attention.backend import AttentionType from vllm.v1.attention.backends.fa_utils import ( get_flash_attn_version, @@ -25,6 +29,8 @@ cascade_attention, ) +logger = init_logger(__name__) + class FlashAttentionDiffKVBackend(FlashAttentionBackend): # Default to 128 for this backend @@ -204,6 +210,23 @@ def forward( # Different head_size for K and V key_cache = kv_cache[..., : self.head_size] value_cache = kv_cache[..., self.head_size :] + # Fix degenerate strides on size-1 dims (e.g. num_kv_heads=1 with TP). + # FA3/4 on H100+ uses TMA, which requires ≥16-byte stride alignment. + # See vllm.utils.torch_utils.canonicalize_singleton_dim_strides. + fixed_k = canonicalize_singleton_dim_strides(key_cache) + fixed_v = canonicalize_singleton_dim_strides(value_cache) + if fixed_k is not key_cache or fixed_v is not value_cache: + logger.debug( + "Canonicalized degenerate KV cache strides (FlashAttentionDiffKV): " + "shape=%s, key strides before=%s after=%s, " + "value strides before=%s after=%s", + key_cache.shape, + key_cache.stride(), + fixed_k.stride(), + value_cache.stride(), + fixed_v.stride(), + ) + key_cache, value_cache = fixed_k, fixed_v if is_quantized_kv_cache(self.kv_cache_dtype): # queries are quantized in the attention layer diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 8f5cb6206bd..2de61a2b1f2 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -43,6 +43,7 @@ from vllm.utils.math_utils import cdiv from vllm.utils.platform_utils import is_pin_memory_available from vllm.utils.torch_utils import ( + canonicalize_singleton_dim_strides, is_quantized_kv_cache, is_strictly_contiguous, nvfp4_kv_cache_full_dim, @@ -1479,6 +1480,21 @@ def forward( stride_order = FlashInferBackend.get_kv_cache_stride_order() kv_cache_permute = kv_cache.permute(*stride_order) # HND and contiguous + # Fix degenerate strides on any size-1 dimension (e.g. num_kv_heads=1 + # with TP=8). 
PyTorch permits non-canonical strides on size-1 dims; + # CUDA TMA requires ≥16-byte alignment on all non-outermost strides. + # canonicalize_singleton_dim_strides patches metadata via as_strided — + # zero-copy. See vllm.utils.torch_utils. + fixed = canonicalize_singleton_dim_strides(kv_cache_permute) + if fixed is not kv_cache_permute: + logger.debug( + "Canonicalized degenerate KV cache strides (FlashInfer): " + "shape=%s, strides before=%s, strides after=%s", + kv_cache_permute.shape, + kv_cache_permute.stride(), + fixed.stride(), + ) + kv_cache_permute = fixed # For NVFP4, the kv_cache last dim is full_dim (data + scale packed). # Split into correctly-strided data and scale views. @@ -1568,10 +1584,11 @@ def forward( else: assert isinstance(attn_metadata.prefill, TRTLLMPrefill) # prefill_query may be non-contiguous or have degenerate strides - # First ensure memory contiguity, then fix degenerate strides - # with reshape. contiguous() alone doesn't fix degenerate - # strides when a dimension has size 1. - prefill_query = prefill_query.contiguous().reshape(prefill_query.shape) + # on size=1 dims. contiguous() ensures memory layout; then + # canonicalize_singleton_dim_strides fixes any remaining + # degenerate strides on size=1 dims for TMA alignment. + prefill_query = prefill_query.contiguous() + prefill_query = canonicalize_singleton_dim_strides(prefill_query) workspace_buffer = _get_trtllm_gen_workspace_buffer() block_tables_prefill = attn_metadata.prefill.block_tables seq_lens_prefill = attn_metadata.prefill.seq_lens @@ -1621,11 +1638,9 @@ def forward( # with fp8 kv cache, we can construct a mock block # and mock kv cache with BF16 KV involved in the prefill # - # The inner (block_size, head_size) dims must be - # contiguous; outer dims may have non-canonical strides - # (e.g. cross-layer unified allocation). - # Degenerate strides on outer dims break TMA descriptors - # (see flashinfer-ai/flashinfer#2232). + kv_cache_permute = canonicalize_singleton_dim_strides( + kv_cache_permute + ) kv_strides = kv_cache_permute.stride() assert ( kv_strides[-1] == 1 @@ -1732,12 +1747,13 @@ def forward( if needs_fp8_out: output[:num_decode_tokens].copy_(out_decode.to(output.dtype)) else: - # decode_query may be non-contiguous or have degenerate strides assert isinstance(attn_metadata.decode, TRTLLMDecode) - # First ensure memory contiguity, then fix degenerate strides - # with reshape. contiguous() alone doesn't fix degenerate - # strides when a dimension has size 1. - decode_query = decode_query.contiguous().reshape(decode_query.shape) + # decode_query may be non-contiguous or have degenerate strides + # on size=1 dims. contiguous() ensures memory layout; then + # canonicalize_singleton_dim_strides fixes any remaining + # degenerate strides on size=1 dims for TMA alignment. + decode_query = decode_query.contiguous() + decode_query = canonicalize_singleton_dim_strides(decode_query) workspace_buffer = _get_trtllm_gen_workspace_buffer() block_tables_decode = attn_metadata.decode.block_tables seq_lens_decode = attn_metadata.decode.seq_lens @@ -1748,11 +1764,7 @@ def forward( assert is_strictly_contiguous(workspace_buffer) assert is_strictly_contiguous(block_tables_decode) assert is_strictly_contiguous(seq_lens_decode) - # kv_cache outer dims may be non-contiguous (e.g. - # cross-layer unified allocation), but inner dims - # (block_size, head_size) must be contiguous and - # strides must be canonical to avoid TMA descriptor - # failures (see flashinfer-ai/flashinfer#2232). 
+ kv_cache_permute = canonicalize_singleton_dim_strides(kv_cache_permute) kv_strides = kv_cache_permute.stride() assert ( kv_strides[-1] == 1 and kv_strides[-2] == kv_cache_permute.shape[-1] From 93e4ec6313a85ae5c4a545ea4a55e26493d2f56a Mon Sep 17 00:00:00 2001 From: lesj0610 Date: Mon, 20 Apr 2026 16:19:52 +0900 Subject: [PATCH 6/6] Restack Gemma4 core: fused MoE activation support Signed-off-by: lesj0610 --- .../model_executor/layers/fused_moe/activation.py | 15 +++++++++++++++ .../layers/fused_moe/cpu_fused_moe.py | 4 ++++ .../layers/fused_moe/experts/cutlass_moe.py | 3 +++ .../layers/fused_moe/fused_batched_moe.py | 2 ++ .../layers/fused_moe/fused_marlin_moe.py | 2 ++ vllm/model_executor/layers/fused_moe/fused_moe.py | 2 ++ .../layers/quantization/moe_wna16.py | 6 +----- 7 files changed, 29 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/activation.py b/vllm/model_executor/layers/fused_moe/activation.py index 3112b3054fc..7a013c8f28a 100644 --- a/vllm/model_executor/layers/fused_moe/activation.py +++ b/vllm/model_executor/layers/fused_moe/activation.py @@ -15,6 +15,7 @@ class MoEActivation(Enum): # and produce output of shape [..., d] SILU = "silu" GELU = "gelu" + GELU_TANH = "gelu_tanh" RELU2 = "relu2" SWIGLUOAI = "swigluoai" SWIGLUSTEP = "swiglustep" @@ -24,6 +25,7 @@ class MoEActivation(Enum): # NOTE: Non-gated activations require the "_no_mul" suffix to be present. SILU_NO_MUL = "silu_no_mul" GELU_NO_MUL = "gelu_no_mul" + GELU_TANH_NO_MUL = "gelu_tanh_no_mul" RELU2_NO_MUL = "relu2_no_mul" @property @@ -53,6 +55,8 @@ def without_mul(self) -> "MoEActivation": @classmethod def from_str(cls, s: str) -> "MoEActivation": """Parse from string for backward compatibility.""" + if s == "gelu_pytorch_tanh": + s = cls.GELU_TANH.value for member in cls: if member.value == s: return member @@ -64,17 +68,20 @@ def from_str(cls, s: str) -> "MoEActivation": _CUSTOM_OP_NAMES: dict[MoEActivation, str] = { MoEActivation.SILU: "silu_and_mul", MoEActivation.GELU: "gelu_and_mul", + MoEActivation.GELU_TANH: "gelu_tanh_and_mul", MoEActivation.SWIGLUOAI: "swigluoai_and_mul", MoEActivation.SWIGLUSTEP: "swiglustep_and_mul", MoEActivation.RELU2: "relu2", MoEActivation.SILU_NO_MUL: "silu_and_mul", MoEActivation.GELU_NO_MUL: "gelu_and_mul", + MoEActivation.GELU_TANH_NO_MUL: "gelu_tanh_and_mul", MoEActivation.RELU2_NO_MUL: "relu2", } _WITHOUT_MUL: dict[MoEActivation, MoEActivation] = { MoEActivation.SILU: MoEActivation.SILU_NO_MUL, MoEActivation.GELU: MoEActivation.GELU_NO_MUL, + MoEActivation.GELU_TANH: MoEActivation.GELU_TANH_NO_MUL, MoEActivation.RELU2: MoEActivation.RELU2_NO_MUL, } @@ -115,6 +122,12 @@ def apply_moe_activation( torch.ops._C.silu_and_mul(output, input) elif activation == MoEActivation.GELU: torch.ops._C.gelu_and_mul(output, input) + elif activation == MoEActivation.GELU_TANH: + if hasattr(torch.ops._C, "gelu_tanh_and_mul"): + torch.ops._C.gelu_tanh_and_mul(output, input) + else: + gate, up = input.chunk(2, dim=-1) + output.copy_(F.gelu(gate, approximate="tanh") * up) elif activation == MoEActivation.SWIGLUOAI: torch.ops._C.swigluoai_and_mul(output, input) elif activation == MoEActivation.SWIGLUSTEP: @@ -127,6 +140,8 @@ def apply_moe_activation( output.copy_(F.silu(input)) elif activation == MoEActivation.GELU_NO_MUL: output.copy_(F.gelu(input)) + elif activation == MoEActivation.GELU_TANH_NO_MUL: + output.copy_(F.gelu(input, approximate="tanh")) elif activation == MoEActivation.RELU2_NO_MUL: F.relu(input, inplace=True) torch.square(input, out=output) diff --git 
a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py index 985f33e1009..1cbe5205bfa 100644 --- a/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/cpu_fused_moe.py @@ -48,6 +48,10 @@ def _gelu_and_mul( MoEActivation.SILU: lambda x: SiluAndMul(compile_native=False).forward_native(x), MoEActivation.SWIGLUOAI: _swigluoai_forward_native, MoEActivation.GELU: _gelu_and_mul, + MoEActivation.GELU_TANH: ( + lambda x: F.gelu(x[..., : x.shape[-1] // 2], approximate="tanh") + * x[..., x.shape[-1] // 2 :] + ), } diff --git a/vllm/model_executor/layers/fused_moe/experts/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/experts/cutlass_moe.py index fdd802e7da3..7c0d0d8d177 100644 --- a/vllm/model_executor/layers/fused_moe/experts/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/cutlass_moe.py @@ -319,6 +319,7 @@ def _supports_activation(activation: MoEActivation) -> bool: return activation in [ MoEActivation.SILU, MoEActivation.GELU, + MoEActivation.GELU_TANH, MoEActivation.SWIGLUOAI, ] @@ -709,10 +710,12 @@ def _supports_activation(activation: MoEActivation) -> bool: return activation in [ MoEActivation.SILU, MoEActivation.GELU, + MoEActivation.GELU_TANH, MoEActivation.SWIGLUOAI, MoEActivation.SWIGLUSTEP, MoEActivation.SILU_NO_MUL, MoEActivation.GELU_NO_MUL, + MoEActivation.GELU_TANH_NO_MUL, MoEActivation.RELU2_NO_MUL, ] diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index bd54cd636b0..8638a11466c 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -786,9 +786,11 @@ def _supports_activation(activation: MoEActivation) -> bool: return activation in [ MoEActivation.SILU, MoEActivation.GELU, + MoEActivation.GELU_TANH, MoEActivation.SWIGLUOAI, MoEActivation.SILU_NO_MUL, MoEActivation.GELU_NO_MUL, + MoEActivation.GELU_TANH_NO_MUL, MoEActivation.RELU2_NO_MUL, ] diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index ebd33019709..3487ac1766e 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -613,10 +613,12 @@ def _supports_activation(activation: MoEActivation) -> bool: return activation in [ MoEActivation.SILU, MoEActivation.GELU, + MoEActivation.GELU_TANH, MoEActivation.SWIGLUOAI, MoEActivation.SWIGLUSTEP, MoEActivation.SILU_NO_MUL, MoEActivation.GELU_NO_MUL, + MoEActivation.GELU_TANH_NO_MUL, MoEActivation.RELU2_NO_MUL, ] diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 7e7bcc70992..1a655934259 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1941,10 +1941,12 @@ def _supports_activation(activation: MoEActivation) -> bool: return activation in [ MoEActivation.SILU, MoEActivation.GELU, + MoEActivation.GELU_TANH, MoEActivation.SWIGLUOAI, MoEActivation.SWIGLUSTEP, MoEActivation.SILU_NO_MUL, MoEActivation.GELU_NO_MUL, + MoEActivation.GELU_TANH_NO_MUL, MoEActivation.RELU2_NO_MUL, ] diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index e5ef3f4c316..d67b386b933 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ 
b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -6,7 +6,6 @@ import torch from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group -from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( FusedMoEQuantConfig, int4_w4a16_moe_quant_config, @@ -372,10 +371,6 @@ def apply( ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts - assert layer.activation == MoEActivation.SILU, ( - f"Only SiLU activation is supported, not {layer.activation}." - ) - return fused_experts( x, layer.w13_qweight, @@ -383,6 +378,7 @@ def apply( topk_weights=topk_weights, topk_ids=topk_ids, inplace=not self.moe.disable_inplace, + activation=layer.activation, apply_router_weight_on_input=layer.apply_router_weight_on_input, global_num_experts=layer.global_num_experts, expert_map=layer.expert_map,
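
For context on PATCH 6/6: below is a minimal standalone sketch (plain PyTorch, outside the vLLM tree) of what the new GELU_TANH / "gelu_pytorch_tanh" gated MoE activation computes. It mirrors the pure-PyTorch fallback branch added to apply_moe_activation for the case where the fused custom op torch.ops._C.gelu_tanh_and_mul is unavailable; the helper name gelu_tanh_and_mul_reference and the tensor shapes are local to this sketch, not part of vLLM.

    import torch
    import torch.nn.functional as F

    def gelu_tanh_and_mul_reference(x: torch.Tensor) -> torch.Tensor:
        # Split the last dim into (gate, up) halves and apply the tanh-approximate
        # GELU to the gate before multiplying, matching the fallback in
        # apply_moe_activation for MoEActivation.GELU_TANH.
        gate, up = x.chunk(2, dim=-1)
        return F.gelu(gate, approximate="tanh") * up

    # Example: a [num_tokens, 2 * intermediate_size] tensor, as produced by a
    # fused MoE expert's w13 projection, reduces to [num_tokens, intermediate_size].
    tokens, intermediate = 4, 8
    x = torch.randn(tokens, 2 * intermediate)
    out = gelu_tanh_and_mul_reference(x)
    assert out.shape == (tokens, intermediate)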