Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions vllm/attention/backends/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,7 @@
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
positions: torch.Tensor | None = None,
) -> torch.Tensor:
raise NotImplementedError

Expand Down Expand Up @@ -372,7 +373,7 @@
raise NotImplementedError

@abstractmethod
def forward(

Check failure on line 376 in vllm/attention/backends/abstract.py

View workflow job for this annotation

GitHub Actions / pre-commit

Signature of "forward" incompatible with supertype "AttentionImpl" [override]

Check failure on line 376 in vllm/attention/backends/abstract.py

View workflow job for this annotation

GitHub Actions / pre-commit

Signature of "forward" incompatible with supertype "AttentionImpl" [override]

Check failure on line 376 in vllm/attention/backends/abstract.py

View workflow job for this annotation

GitHub Actions / pre-commit

Signature of "forward" incompatible with supertype "AttentionImpl" [override]

Check failure on line 376 in vllm/attention/backends/abstract.py

View workflow job for this annotation

GitHub Actions / pre-commit

Signature of "forward" incompatible with supertype "AttentionImpl" [override]
self,
layer: AttentionLayer,
hidden_states_or_cq: torch.Tensor,
Expand Down
77 changes: 63 additions & 14 deletions vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@

if current_platform.is_rocm():
from vllm.platforms.rocm import on_gfx9

if envs.VLLM_ROCM_USE_AITER:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AITER flag management is handled in _aiter_ops.py. Please move all of these flags there and use rocm_aiter_ops.is_enabled() along with any new flags defined there.

VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = (
envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
)
else:
VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False
else:
on_gfx9 = lambda *args, **kwargs: False
Comment on lines 46 to 56
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The variable VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE is only defined if current_platform.is_rocm() is true. This will cause a NameError on other platforms (e.g., CUDA) where this variable is used later in unified_attention_with_output. The definition should be refactored to ensure it's always defined, regardless of the platform.

Suggested change
if current_platform.is_rocm():
from vllm.platforms.rocm import on_gfx9
if envs.VLLM_ROCM_USE_AITER:
VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = (
envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
)
else:
VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False
else:
on_gfx9 = lambda *args, **kwargs: False
if current_platform.is_rocm():
from vllm.platforms.rocm import on_gfx9
else:
on_gfx9 = lambda *args, **kwargs: False
if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = (
envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
)
else:
VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False


Expand Down Expand Up @@ -235,6 +242,7 @@
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: str | None = None,
attn_backend: type[AttentionBackend] | None = None,
rotary_emb: nn.Module | None = None,
**extra_impl_args,
) -> None:
"""
Expand Down Expand Up @@ -310,6 +318,7 @@
kv_sharing_target_layer_name,
**extra_impl_args,
)
self.impl.rotary_emb = rotary_emb

Check failure on line 321 in vllm/attention/layer.py

View workflow job for this annotation

GitHub Actions / pre-commit

"AttentionImpl[Any]" has no attribute "rotary_emb" [attr-defined]

Check failure on line 321 in vllm/attention/layer.py

View workflow job for this annotation

GitHub Actions / pre-commit

"AttentionImpl[Any]" has no attribute "rotary_emb" [attr-defined]

Check failure on line 321 in vllm/attention/layer.py

View workflow job for this annotation

GitHub Actions / pre-commit

"AttentionImpl[Any]" has no attribute "rotary_emb" [attr-defined]

Check failure on line 321 in vllm/attention/layer.py

View workflow job for this annotation

GitHub Actions / pre-commit

"AttentionImpl[Any]" has no attribute "rotary_emb" [attr-defined]
self.backend = AttentionBackendEnum[self.attn_backend.get_name()]
self.dtype = dtype

Expand Down Expand Up @@ -365,6 +374,7 @@
# shape does not match the query shape, so we optionally let the model
# definition specify the output tensor shape.
output_shape: torch.Size | None = None,
positions: torch.Tensor = None,
) -> torch.Tensor:
"""
The KV cache is stored inside this class and is accessed via
Expand All @@ -377,7 +387,6 @@
"""
if self.calculate_kv_scales:
torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
output_dtype = query.dtype
if self.query_quant is not None:
# quantizing with a simple torch operation enables
# torch.compile to fuse this into previous ops
Expand All @@ -392,7 +401,15 @@

if self.use_output:
output_shape = output_shape if output_shape is not None else query.shape
output = torch.empty(output_shape, dtype=output_dtype, device=query.device)
if positions is not None:
output = torch.empty(
output_shape, dtype=query.dtype, device=query.device
)
else:
output = torch.zeros(
output_shape, dtype=query.dtype, device=query.device
Comment on lines 402 to +410

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P0 Badge Output buffer now allocated in quantized FP8 dtype

In the attention forward path, the output tensor is now created with dtype=query.dtype (lines 402‑410). When FP8 query quantization is active, self.query_quant converts query to an FP8 tensor before this allocation. The previous code cached the pre‑quantization dtype (output_dtype) so the output buffer remained fp16/bf16. After this change the output is allocated in FP8, but downstream attention kernels expect the regular activation dtype, so the call either fails or produces incorrect results whenever query quantization is enabled. Capture the original dtype before quantizing and use it for output to avoid creating an FP8 output buffer.

Useful? React with 👍 / 👎.

)

hidden_size = output_shape[-1]
# Reshape the query, key, and value tensors.
# NOTE(woosuk): We do this outside the custom op to minimize the
Expand All @@ -414,7 +431,13 @@
)
else:
torch.ops.vllm.unified_attention_with_output(
query, key, value, output, self.layer_name
query,
key,
value,
output,
self.layer_name,
None,
positions=positions,
)
return output.view(-1, hidden_size)
else:
Expand Down Expand Up @@ -941,19 +964,44 @@
layer_name: str,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
positions: torch.Tensor | None = None,
) -> None:
attn_metadata, self, kv_cache = get_attention_context(layer_name)
self.impl.forward(
self,
query,
key,
value,
kv_cache,
attn_metadata,
output=output,
output_scale=output_scale,
output_block_scale=output_block_scale,
)
from vllm.v1.attention.backends.rocm_aiter_fa import AiterFlashAttentionImpl

if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE and isinstance(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AITER flag management is handled in _aiter_ops.py. Please move all of these flags there and use rocm_aiter_ops.is_xxx_enabled() along with any new flags defined there.

self.impl, AiterFlashAttentionImpl
):
# fusing RoPE with flushing kv_cache operation
assert (
hasattr(self.impl, "rotary_emb")
and self.impl.rotary_emb is not None
and positions is not None
), f"rotary_emb not found in {self.impl=} and positions cannot be None"
self.impl.forward(
self,
query,
key,
value,
kv_cache,
attn_metadata,
output=output,
output_scale=output_scale,
positions=positions,
)
else:
assert positions is None, f"positions must be None {positions=}"
self.impl.forward(
self,
query,
key,
value,
kv_cache,
attn_metadata,
output=output,
output_scale=output_scale,
output_block_scale=output_block_scale,
)


def unified_attention_with_output_fake(
Expand All @@ -964,6 +1012,7 @@
layer_name: str,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
positions: torch.Tensor | None = None,
) -> None:
return

Expand Down
6 changes: 6 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@
VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8: bool = False
VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: bool = False
VLLM_ROCM_FP8_MFMA_PAGE_ATTN: bool = False
VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE: bool = True
Copy link
Collaborator

@tjtanaa tjtanaa Nov 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see that this is enabled by default.
Does this apply to all models?
If it can be applied to all models, do we see a general improvement?
If so, maybe we don't need a flag to manage it; instead, whenever AITER is enabled, we could simply use the fusion op.

VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8_CUTLASS: bool = False
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True
VLLM_TUNED_CONFIG_FOLDER: str | None = None
Expand Down Expand Up @@ -1393,6 +1394,10 @@ def get_vllm_port() -> int | None:
"VLLM_ROCM_FP8_MFMA_PAGE_ATTN": lambda: bool(
int(os.getenv("VLLM_ROCM_FP8_MFMA_PAGE_ATTN", "0"))
),
# Use AITER Triton fused RoPE, zeros, and reshape_and_cache kernel
"VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE": lambda: bool(
int(os.getenv("VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE", "1"))
),
# Whether to use pytorch symmetric memory for allreduce
"VLLM_ALLREDUCE_USE_SYMM_MEM": lambda: bool(
int(os.getenv("VLLM_ALLREDUCE_USE_SYMM_MEM", "1"))
Expand Down Expand Up @@ -1615,6 +1620,7 @@ def compute_hash() -> str:
"VLLM_ROCM_QUICK_REDUCE_CAST_BF16_TO_FP16",
"VLLM_ROCM_QUICK_REDUCE_MAX_SIZE_BYTES_MB",
"VLLM_ROCM_FP8_MFMA_PAGE_ATTN",
"VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE",
"VLLM_ENABLE_INDUCTOR_MAX_AUTOTUNE",
"VLLM_ENABLE_INDUCTOR_COORDINATE_DESCENT_TUNING",
"VLLM_NVFP4_GEMM_BACKEND",
Expand Down
21 changes: 19 additions & 2 deletions vllm/model_executor/models/qwen3.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from torch import nn
from transformers import Qwen3Config

import vllm.envs as envs
from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
Expand All @@ -41,6 +42,7 @@
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsEagle3, SupportsLoRA, SupportsPP
Expand All @@ -49,6 +51,12 @@
from .utils import AutoWeightsLoader, PPMissingLayer, extract_layer_index, maybe_prefix

logger = init_logger(__name__)
if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AITER flag management is handled in _aiter_ops.py. Please move all of these flags there and use rocm_aiter_ops.is_enabled() along with any new flags defined there.

envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
)
else:
VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False
Comment on lines +54 to +59
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This logic for setting VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE is duplicated from vllm/attention/layer.py. To improve maintainability and avoid potential inconsistencies, this flag should be defined in a single location and imported where needed. Please remove this duplicated block and import the flag from vllm.attention.layer like so:

from vllm.attention.layer import VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE



class Qwen3Attention(nn.Module):
Expand Down Expand Up @@ -132,6 +140,11 @@ def __init__(
}
if dual_chunk_attention_config
else {},
rotary_emb=(
self.rotary_emb
if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

likewise

else None
),
)
self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
Expand All @@ -150,8 +163,12 @@ def forward(
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
k_by_head = self.k_norm(k_by_head)
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

likewise

attn_output = self.attn(q, k, v, positions=positions)
else:
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)

output, _ = self.o_proj(attn_output)
return output

Expand Down
21 changes: 19 additions & 2 deletions vllm/model_executor/models/qwen3_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import torch
from torch import nn

import vllm.envs as envs
from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
Expand Down Expand Up @@ -63,6 +64,7 @@
maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.utils import sequence_parallel_chunk
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors

from .interfaces import MixtureOfExperts, SupportsEagle3, SupportsLoRA, SupportsPP
Expand All @@ -77,6 +79,12 @@
)

logger = init_logger(__name__)
if current_platform.is_rocm() and envs.VLLM_ROCM_USE_AITER:
VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

likewise

envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
)
else:
VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE = False
Comment on lines +82 to +87
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This logic for setting VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE is duplicated from vllm/attention/layer.py. To improve maintainability and avoid potential inconsistencies, this flag should be defined in a single location and imported where needed. Please remove this duplicated block and import the flag from vllm.attention.layer like so:

from vllm.attention.layer import VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE



class Qwen3MoeMLP(nn.Module):
Expand Down Expand Up @@ -291,6 +299,11 @@ def __init__(
}
if dual_chunk_attention_config
else {},
rotary_emb=(
self.rotary_emb
if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

likewise

else None
),
)

self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
Expand All @@ -311,8 +324,12 @@ def forward(
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim)
k_by_head = self.k_norm(k_by_head)
k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
if VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

likewise

attn_output = self.attn(q, k, v, positions=positions)
else:
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)

output, _ = self.o_proj(attn_output)
return output

Expand Down
78 changes: 60 additions & 18 deletions vllm/v1/attention/backends/rocm_aiter_fa.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import torch

import vllm.envs as envs
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
Expand Down Expand Up @@ -35,6 +36,9 @@

from vllm.triton_utils import tl, triton

if envs.VLLM_ROCM_USE_AITER_TRITON_FUSED_ROPE_ZEROS_KV_CACHE:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

likewise

from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache

def block_size(x, head_dim):
return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))

Expand Down Expand Up @@ -637,6 +641,7 @@ def forward(
output: torch.Tensor | None = None,
output_scale: torch.Tensor | None = None,
output_block_scale: torch.Tensor | None = None,
positions: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward pass with AiterFlashAttention.

Expand Down Expand Up @@ -675,25 +680,62 @@ def forward(
# performance to make sure it does not introduce any overhead.
num_actual_tokens = attn_metadata.num_actual_tokens
key_cache, value_cache = kv_cache.unbind(0)
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping
# is not padded. However, we don't need to do
# key[:num_actual_tokens] and value[:num_actual_tokens] because
# the reshape_and_cache_flash op uses the slot_mapping's shape
# to determine the number of actual tokens.

torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
if positions is not None and query.shape[0] <= 256:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The value 256 is a magic number that determines the token threshold for using the fused kernel. It should be defined as a named constant, for example _FUSED_QK_ROPE_RESHAPE_AND_CACHE_MAX_TOKENS = 256, at the top of the file to improve readability and maintainability.

Suggested change
if positions is not None and query.shape[0] <= 256:
if positions is not None and query.shape[0] <= 256: # TODO: make this a constant

assert self.kv_sharing_target_layer_name is None, (
"self.kv_sharing_target_layer_name cannot be None"
)
assert hasattr(self, "rotary_emb"), f"rotary_emb not found in {self}"
cos, sin = self.rotary_emb.cos_sin_cache.chunk(2, dim=-1)
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8")
if is_fp8_kv_cache:
key_cache = key_cache.view(current_platform.fp8_dtype())
value_cache = value_cache.view(current_platform.fp8_dtype())

query, key, key_cache, value_cache, output = (
fused_qk_rope_reshape_and_cache(
query,
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
positions,
cos,
sin,
layer._k_scale,
layer._v_scale,
self.rotary_emb.is_neox_style,
flash_layout=True,
apply_scale=is_fp8_kv_cache,
offs=None,
q_out=query,
k_out=key,
output_zeros=True,
zeros_out=output,
)
)
else:
if positions is not None:
query, key = self.rotary_emb(positions, query, key)

if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)

if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(current_platform.fp8_dtype())
Expand Down
Loading