diff --git a/tests/compile/fusions_e2e/test_tp2_ar_rms.py b/tests/compile/fusions_e2e/test_tp2_ar_rms.py index 4b0a0859b023..396cb520c436 100644 --- a/tests/compile/fusions_e2e/test_tp2_ar_rms.py +++ b/tests/compile/fusions_e2e/test_tp2_ar_rms.py @@ -88,6 +88,7 @@ def test_tp2_ar_rms_fp8_fusions( fuse_attn_quant=True, enable_qk_norm_rope_fusion=True, fuse_allreduce_rms=True, + fuse_rope_kvcache=False, # FIXME: disable to avoid compile range split ), ) @@ -150,6 +151,7 @@ def test_tp2_ar_rms_fp4_fusions( fuse_act_quant=True, fuse_attn_quant=True, fuse_allreduce_rms=True, + fuse_rope_kvcache=False, # FIXME: disable to avoid compile range split ), ) @@ -204,6 +206,7 @@ def test_tp2_ar_rms_fusions( pass_config=PassConfig( enable_qk_norm_rope_fusion=True, fuse_allreduce_rms=True, + fuse_rope_kvcache=False, # FIXME: disable to avoid compile range split ), ) diff --git a/tests/compile/fusions_e2e/test_tp2_async_tp.py b/tests/compile/fusions_e2e/test_tp2_async_tp.py index 609377e68958..bfd30bf9f07d 100644 --- a/tests/compile/fusions_e2e/test_tp2_async_tp.py +++ b/tests/compile/fusions_e2e/test_tp2_async_tp.py @@ -71,6 +71,7 @@ def test_tp2_async_tp_fp8_fusions( enable_sp=True, fuse_gemm_comms=True, fuse_allreduce_rms=False, + fuse_rope_kvcache=False, # FIXME: disable to avoid compile range split # Override threshold for testing (models have small hidden_size) sp_min_token_num=512, ), @@ -132,6 +133,7 @@ def test_tp2_async_tp_fusions( enable_sp=True, fuse_gemm_comms=True, fuse_allreduce_rms=False, + fuse_rope_kvcache=False, # FIXME: disable to avoid compile range split # Override threshold for testing (models have small hidden_size) sp_min_token_num=512, ), @@ -197,6 +199,7 @@ def test_tp2_sp_ar_rms_fp8_fusions( enable_sp=True, fuse_gemm_comms=True, fuse_allreduce_rms=True, + fuse_rope_kvcache=False, # FIXME: disable to avoid compile range split # Override threshold for testing (models have small hidden_size) sp_min_token_num=512, ), @@ -258,6 +261,7 @@ def test_tp2_sp_ar_rms_fusions( enable_sp=True, fuse_gemm_comms=True, fuse_allreduce_rms=True, + fuse_rope_kvcache=False, # FIXME: disable to avoid compile range split # Override threshold for testing (models have small hidden_size) sp_min_token_num=512, ), diff --git a/tests/compile/passes/test_rope_kvcache_fusion.py b/tests/compile/passes/test_rope_kvcache_fusion.py index bab70c12a89b..eedb30f03096 100644 --- a/tests/compile/passes/test_rope_kvcache_fusion.py +++ b/tests/compile/passes/test_rope_kvcache_fusion.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import copy import pytest import torch @@ -8,7 +9,7 @@ from tests.compile.backend import TestBackend from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata from vllm._aiter_ops import is_aiter_found_and_supported, rocm_aiter_ops -from vllm.compilation.passes.fusion.matcher_utils import ROTARY_OP +from vllm.compilation.passes.fusion.matcher_utils import QUANT_OPS, ROTARY_OP from vllm.compilation.passes.fusion.rope_kvcache_fusion import RopeKVCacheFusionPass from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass @@ -17,35 +18,41 @@ ) from vllm.compilation.passes.utility.split_coalescing import SplitCoalescingPass from vllm.config import ( + AttentionConfig, CacheConfig, CompilationConfig, CompilationMode, ModelConfig, PassConfig, + SchedulerConfig, VllmConfig, + set_current_vllm_config, ) from vllm.forward_context import get_forward_context, set_forward_context from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kFp8StaticTensorSym, +) from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.platforms import current_platform +from vllm.utils.flashinfer import has_flashinfer from vllm.utils.torch_utils import _encode_layer_name from vllm.v1.attention.backend import ( AttentionBackend, CommonAttentionMetadata, ) from vllm.v1.attention.backends.registry import AttentionBackendEnum -from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.v1.kv_cache_interface import AttentionSpec, get_kv_quant_mode INDEX_SELECT_OP = torch.ops.aten.index.Tensor VLLM_UNIFIED_KV_CACHE_UPDATE_OP = torch.ops.vllm.unified_kv_cache_update FP8_DTYPE = current_platform.fp8_dtype() -class QKRoPEKVCacheTestModel(torch.nn.Module): +class QKRoPEKVCacheTestModelBase(torch.nn.Module): def __init__( self, vllm_config: VllmConfig, - attn_backend: AttentionBackendEnum, num_heads: int, num_kv_heads: int, head_size: int, @@ -53,6 +60,7 @@ def __init__( dtype: torch.dtype, device: torch.device, prefix: str = "model.layers.0.self_attn.attn", + attn_backend: AttentionBackendEnum = None, ): super().__init__() self.num_heads = num_heads @@ -87,7 +95,7 @@ def __init__( cache_config=vllm_config.cache_config, quant_config=vllm_config.quant_config, prefix=prefix, - attn_backend=attn_backend.get_class(), + attn_backend=attn_backend.get_class() if attn_backend is not None else None, ) self.attn_backend: type[AttentionBackend] = self.attn.get_attn_backend() assert not self.attn_backend.forward_includes_kv_cache_update, ( @@ -96,18 +104,14 @@ def __init__( self.attn._k_scale = self.attn._k_scale.to(device) self.attn._v_scale = self.attn._v_scale.to(device) - kv_cache_dtype_str = vllm_config.cache_config.cache_dtype - self.kv_cache_dtype = ( - FP8_DTYPE if kv_cache_dtype_str.startswith("fp8") else self.dtype - ) - # Initialize attn MetadataBuilder self.builder = self.attn.attn_backend.get_builder_cls()( kv_cache_spec=AttentionSpec( block_size=self.block_size, num_kv_heads=self.num_kv_heads, head_size=head_size, - dtype=self.kv_cache_dtype, + dtype=self.attn.kv_cache_torch_dtype, + kv_quant_mode=get_kv_quant_mode(self.attn.kv_cache_dtype), ), layer_names=[self.attn.layer_name], vllm_config=vllm_config, @@ -143,7 +147,7 @@ def build_attn_metadata(self, batch_size: int) -> CommonAttentionMetadata: # Create dummy KV cache raw_tensor = torch.zeros( 2 * num_blocks * self.block_size * self.num_kv_heads * self.head_size, - dtype=self.kv_cache_dtype, + dtype=self.attn.kv_cache_torch_dtype, device=self.device, ) raw_tensor = raw_tensor.view(kv_cache_shape) @@ -158,6 +162,19 @@ def build_attn_metadata(self, batch_size: int) -> CommonAttentionMetadata: return attn_metadata + def forward( + self, qkv: torch.Tensor, positions: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + raise NotImplementedError + + def ops_in_model_before(self) -> list[torch._ops.OpOverload]: + raise NotImplementedError + + def ops_in_model_after(self) -> list[torch._ops.OpOverload]: + raise NotImplementedError + + +class QKRoPEKVCacheTestModel(QKRoPEKVCacheTestModelBase): def forward( self, qkv: torch.Tensor, positions: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: @@ -191,6 +208,39 @@ def ops_in_model_after(self) -> list[torch._ops.OpOverload]: return [torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default] +class QKRoPEQuantKVCacheTestModel(QKRoPEKVCacheTestModelBase): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + assert self.attn.query_quant is not None + + def forward(self, qkv: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: + qkv = qkv.clone() + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v) + return attn_output + + def ops_in_model_before(self) -> list[torch._ops.OpOverload]: + ops = [] + if self.enable_rope_custom_op: + if self.rotary_emb.use_flashinfer: + ops.append(torch.ops.vllm.flashinfer_rotary_embedding.default) + else: + ops.append(ROTARY_OP) + else: + ops.append(INDEX_SELECT_OP) + if self.attn.query_quant.enabled(): + ops.append(QUANT_OPS[kFp8StaticTensorSym]) + else: + ops.append(torch.ops.aten.reciprocal) + ops.append(torch.ops.vllm.unified_kv_cache_update.default) + return ops + + def ops_in_model_after(self) -> list[torch._ops.OpOverload]: + return [torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default] + + @pytest.mark.parametrize( "attn_backend", [ @@ -259,13 +309,13 @@ def test_rope_kvcache_fusion( model = QKRoPEKVCacheTestModel( vllm_config=vllm_config, - attn_backend=attn_backend, num_heads=num_heads, num_kv_heads=num_kv_heads, head_size=head_size, is_neox=is_neox, dtype=dtype, device=torch.get_default_device(), + attn_backend=attn_backend, ) fusion_pass = RopeKVCacheFusionPass(vllm_config) @@ -333,3 +383,156 @@ def test_rope_kvcache_fusion( atol=ATOL, rtol=RTOL, ) + + +@pytest.mark.parametrize("attn_backend", [AttentionBackendEnum.FLASHINFER]) +@pytest.mark.parametrize("model_name", ["openai/gpt-oss-20b"]) +@pytest.mark.parametrize("enable_rope_custom_op", [True]) +@pytest.mark.parametrize("enable_quant_custom_op", [True, False]) +@pytest.mark.parametrize("enable_flashinfer_rope", [True, False]) +@pytest.mark.parametrize("batch_size", [7, 64, 533]) +@pytest.mark.parametrize("num_heads", [64]) +@pytest.mark.parametrize("num_kv_heads", [8]) +@pytest.mark.parametrize("head_size", [64]) +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("is_neox", [True, False]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("kv_cache_dtype", ["fp8"]) +@pytest.mark.skipif( + not ( + current_platform.is_cuda() + and current_platform.is_device_capability((10, 0)) + and has_flashinfer() + ), + reason="Only test on CUDA Blackwell platform with FlashInfer installed", +) +def test_rope_quant_kvcache_fusion( + attn_backend: AttentionBackendEnum, + model_name: str, + enable_rope_custom_op: bool, + enable_quant_custom_op: bool, + enable_flashinfer_rope: bool, + batch_size: int, + num_heads: int, + num_kv_heads: int, + head_size: int, + block_size: int, + is_neox: bool, + dtype: torch.dtype, + kv_cache_dtype: str, + monkeypatch: pytest.MonkeyPatch, +): + monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1") + if enable_flashinfer_rope: + monkeypatch.setenv("VLLM_USE_FLASHINFER_ROPE", "1") + + torch.set_default_device("cuda") + torch.set_default_dtype(dtype) + torch.manual_seed(42) + + custom_ops: list[str] = [] + if enable_rope_custom_op: + custom_ops.append("+rotary_embedding") + if enable_quant_custom_op: + custom_ops.append("+quant_fp8") + + model_config = ModelConfig( + model=model_name, + max_model_len=2048, + dtype=dtype, + ) + + vllm_config = VllmConfig( + model_config=model_config, + scheduler_config=SchedulerConfig( + max_num_seqs=1024, + max_model_len=model_config.max_model_len, + is_encoder_decoder=model_config.is_encoder_decoder, + ), + cache_config=CacheConfig( + block_size=block_size, + cache_dtype=kv_cache_dtype, + ), + attention_config=AttentionConfig( + backend=attn_backend, + ), + compilation_config=CompilationConfig( + mode=CompilationMode.VLLM_COMPILE, + custom_ops=custom_ops, + pass_config=PassConfig( + eliminate_noops=False, + fuse_rope_kvcache=False, + ), + ), + ) + + hidden_size = head_size * (num_heads + num_kv_heads * 2) + qkv = torch.randn(batch_size, hidden_size, dtype=dtype) + pos = torch.arange(batch_size, dtype=torch.long) + + # Run model directly without fusion + vllm_config_unfused = copy.deepcopy(vllm_config) + with ( + set_current_vllm_config(vllm_config_unfused), + set_forward_context(attn_metadata=None, vllm_config=vllm_config_unfused), + ): + model_unfused = QKRoPEQuantKVCacheTestModel( + vllm_config=vllm_config_unfused, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_size=head_size, + is_neox=is_neox, + dtype=dtype, + device=torch.get_default_device(), + ) + forward_ctx = get_forward_context() + forward_ctx.attn_metadata = model_unfused.build_attn_metadata(batch_size) + forward_ctx.slot_mapping = { + model_unfused.layer_name: forward_ctx.attn_metadata.slot_mapping + } + compiled_unfused = torch.compile(model_unfused, fullgraph=True) + result_unfused = compiled_unfused(qkv.clone(), pos.clone()) + + # Run model with fusion enabled + vllm_config.compilation_config.pass_config = PassConfig( + eliminate_noops=True, + fuse_rope_kvcache=True, + ) + with ( + set_current_vllm_config(vllm_config), + set_forward_context(attn_metadata=None, vllm_config=vllm_config), + ): + model_fused = QKRoPEQuantKVCacheTestModel( + vllm_config=vllm_config, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_size=head_size, + is_neox=is_neox, + dtype=dtype, + device=torch.get_default_device(), + ) + forward_ctx = get_forward_context() + forward_ctx.attn_metadata = model_fused.build_attn_metadata(batch_size) + forward_ctx.slot_mapping = { + model_fused.layer_name: forward_ctx.attn_metadata.slot_mapping + } + + # Create test backend with fusion passes enabled + fusion_pass = RopeKVCacheFusionPass(vllm_config) + passes = [ + NoOpEliminationPass(vllm_config), + SplitCoalescingPass(vllm_config), + ScatterSplitReplacementPass(vllm_config), + fusion_pass, + PostCleanupPass(vllm_config), + ] + backend = TestBackend(*passes) + compiled_fused = torch.compile(model_fused, backend=backend, fullgraph=True) + result_fused = compiled_fused(qkv.clone(), pos.clone()) + + assert fusion_pass.matched_count == 1 + + backend.check_before_ops(model_fused.ops_in_model_before()) + backend.check_after_ops(model_fused.ops_in_model_after()) + + torch.testing.assert_close(result_unfused, result_fused, atol=1e-2, rtol=1e-2) diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index 91decf6658a5..4a1ff41c2a88 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -84,9 +84,15 @@ def create_common_attn_metadata( block_table_tensor = torch.arange( num_blocks, dtype=torch.int32, device=device ).view(batch_spec.batch_size, max_blocks) - slot_mapping = torch.arange(num_tokens, dtype=torch.int64, device=device).view( - num_tokens - ) + # Compute slot_mapping consistent with block_table: + slots = [] + for i in range(batch_spec.batch_size): + context_len = batch_spec.seq_lens[i] - batch_spec.query_lens[i] + for j in range(batch_spec.query_lens[i]): + global_pos = context_len + j + physical_block = block_table_tensor[i, global_pos // block_size].item() + slots.append(physical_block * block_size + global_pos % block_size) + slot_mapping = torch.tensor(slots, dtype=torch.int64, device=device) else: block_table_tensor = torch.randint( 0, diff --git a/vllm/compilation/passes/fusion/rope_kvcache_fusion.py b/vllm/compilation/passes/fusion/rope_kvcache_fusion.py index bc6754188aa6..c97a6167933b 100644 --- a/vllm/compilation/passes/fusion/rope_kvcache_fusion.py +++ b/vllm/compilation/passes/fusion/rope_kvcache_fusion.py @@ -15,6 +15,11 @@ Attention, get_attention_context, ) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + QuantKey, + kStaticTensorScale, +) +from vllm.platforms import current_platform from vllm.utils.torch_utils import ( _USE_LAYERNAME, LayerNameType, @@ -25,16 +30,13 @@ from ..inductor_pass import enable_fake_mode from ..vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass -from .matcher_utils import ( - MatcherRotaryEmbedding, -) -from .rms_quant_fusion import ( - empty_bf16, - empty_i64, -) +from .matcher_utils import MatcherQuantFP8, MatcherRotaryEmbedding +from .rms_quant_fusion import empty_bf16, empty_fp32, empty_i64 logger = init_logger(__name__) +FP8_DTYPE = current_platform.fp8_dtype() + def fused_rope_and_unified_kv_cache_update_impl( query: torch.Tensor, @@ -44,6 +46,8 @@ def fused_rope_and_unified_kv_cache_update_impl( cos_sin_cache: torch.Tensor, is_neox: bool, layer_name: LayerNameType, + query_quant_scale: torch.Tensor | None = None, + query_quant_out: torch.Tensor | None = None, ) -> torch.Tensor: """ This impl fetches the KV cache and slot mapping from the forward context, @@ -53,7 +57,9 @@ def fused_rope_and_unified_kv_cache_update_impl( the data dependency between them to ensure torch.compile preserves ordering. """ layer_name = _resolve_layer_name(layer_name) - _, attn_layer, kv_cache, layer_slot_mapping = get_attention_context(layer_name) + attn_metadata, attn_layer, kv_cache, layer_slot_mapping = get_attention_context( + layer_name + ) if layer_slot_mapping is not None: attn_layer.impl.do_rope_and_kv_cache_update( attn_layer, @@ -65,6 +71,9 @@ def fused_rope_and_unified_kv_cache_update_impl( is_neox, kv_cache, layer_slot_mapping, + attn_metadata, + query_quant_scale, + query_quant_out, ) return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype) @@ -78,6 +87,8 @@ def fused_rope_and_unified_kv_cache_update_fake( cos_sin_cache: torch.Tensor, is_neox: bool, layer_name: LayerNameType, + query_quant_scale: torch.Tensor | None = None, + query_quant_out: torch.Tensor | None = None, ) -> torch.Tensor: return torch.empty(0, device=query.device, dtype=query.dtype) @@ -85,7 +96,7 @@ def fused_rope_and_unified_kv_cache_update_fake( direct_register_custom_op( op_name="fused_rope_and_unified_kv_cache_update", op_func=fused_rope_and_unified_kv_cache_update_impl, - mutates_args=["query", "key"], + mutates_args=["query", "key", "query_quant_out"], fake_impl=fused_rope_and_unified_kv_cache_update_fake, ) @@ -223,6 +234,178 @@ def fwd_and_view_to_reshape(*args, **kwargs) -> fx.GraphModule: ) +class RopeQuantReshapeKVCachePattern: + """ + This pattern matches the following unfused inplace ops: + q, k = rotary_embedding(positions, q, k, head_size, cos_sin_cache, is_neox) + q = static_scaled_fp8_quant(q, scale) + kv_cache_dummy = unified_kv_cache_update(k, v, layer_name) + + and replaces it with the fused inplace op: + kv_cache_dummy = fused_rope_and_unified_kv_cache_update( + q, k, v, positions, cos_sin_cache, is_neox, layer_name, scale + ) + """ + + FUSED_OP = torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default + + def __init__( + self, + layer: Attention, + quant_key: QuantKey, + is_neox: bool, + use_flashinfer_rope: bool = False, + ) -> None: + self.layer_name = layer.layer_name + self.num_heads = layer.num_heads + self.num_kv_heads = layer.num_kv_heads + self.head_size = layer.head_size + self.head_size_v = layer.head_size_v + self.is_neox = is_neox + self.use_flashinfer_rope = use_flashinfer_rope + + self.q_size = self.num_heads * self.head_size + self.k_size = self.num_kv_heads * self.head_size + self.v_size = self.num_kv_heads * self.head_size_v + + self.rope_matcher = MatcherRotaryEmbedding( + is_neox=self.is_neox, + head_size=self.head_size, + num_heads=self.num_heads, + num_kv_heads=self.num_kv_heads, + use_flashinfer=self.use_flashinfer_rope, + ) + self.quant_key = quant_key + self.quant_matcher = MatcherQuantFP8(quant_key) + + def get_inputs(self) -> list: + T = 5 + L = 4096 + qkv = empty_bf16(T, self.q_size + self.k_size + self.v_size) + positions = empty_i64(T) + cos_sin_cache = empty_fp32(L, self.head_size) + query_quant_scale = empty_fp32(1, 1) + inputs: list = [qkv, positions, cos_sin_cache, query_quant_scale] + if _USE_LAYERNAME: + inputs.append(_encode_layer_name(self.layer_name)) + return inputs + + def _mk_pattern_with_layer_name_input(self, _ln): + """Pattern/replacement with layer_name as an explicit input.""" + + def pattern( + qkv: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + query_quant_scale: torch.Tensor, + layer_name: LayerNameType, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) + q, k = self.rope_matcher(positions, q, k, cos_sin_cache) + q, _ = self.quant_matcher(q, query_quant_scale) + q = q.view(-1, self.num_heads, self.head_size) + k = k.view(-1, self.num_kv_heads, self.head_size) + v = v.view(-1, self.num_kv_heads, self.head_size_v) + dummy = torch.ops.vllm.unified_kv_cache_update(k, v, layer_name) + return dummy, q, k, v + + def replacement( + qkv: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + query_quant_scale: torch.Tensor, + layer_name: LayerNameType, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) + q = q.view(-1, self.num_heads, self.head_size) + k = k.view(-1, self.num_kv_heads, self.head_size) + v = v.view(-1, self.num_kv_heads, self.head_size_v) + q_quant_out = torch.empty_like(q, dtype=self.quant_key.dtype) + results = auto_functionalized( + self.FUSED_OP, + query=q, + key=k, + value=v, + positions=positions, + cos_sin_cache=cos_sin_cache.to(torch.float32), + is_neox=self.is_neox, + layer_name=layer_name, + query_quant_scale=query_quant_scale, + query_quant_out=q_quant_out, + ) + return results[0], results[3], results[2], v + + return pattern, replacement + + def _mk_pattern_with_layer_name_closure(self, _ln): + """Pattern/replacement with layer_name as a closure constant.""" + + def pattern( + qkv: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + query_quant_scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) + q, k = self.rope_matcher(positions, q, k, cos_sin_cache) + q, _ = self.quant_matcher(q, query_quant_scale) + q = q.view(-1, self.num_heads, self.head_size) + k = k.view(-1, self.num_kv_heads, self.head_size) + v = v.view(-1, self.num_kv_heads, self.head_size_v) + dummy = torch.ops.vllm.unified_kv_cache_update(k, v, _ln) + return dummy, q, k, v + + def replacement( + qkv: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + query_quant_scale: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + q, k, v = qkv.split([self.q_size, self.k_size, self.v_size], dim=-1) + q = q.view(-1, self.num_heads, self.head_size) + k = k.view(-1, self.num_kv_heads, self.head_size) + v = v.view(-1, self.num_kv_heads, self.head_size_v) + q_quant_out = torch.empty_like(q, dtype=self.quant_key.dtype) + results = auto_functionalized( + self.FUSED_OP, + query=q, + key=k, + value=v, + positions=positions, + cos_sin_cache=cos_sin_cache.to(torch.float32), + is_neox=self.is_neox, + layer_name=_ln, + query_quant_scale=query_quant_scale, + query_quant_out=q_quant_out, + ) + return results[0], results[3], results[2], v + + return pattern, replacement + + def register(self, pm_pass: PatternMatcherPass) -> None: + _ln = _encode_layer_name(self.layer_name) + + if _USE_LAYERNAME: + pattern, replacement = self._mk_pattern_with_layer_name_input(_ln) + else: + pattern, replacement = self._mk_pattern_with_layer_name_closure(_ln) + + # NOTE: use view_to_reshape to unify view/reshape to simplify + # pattern and increase matching opportunities + def fwd_and_view_to_reshape(*args, **kwargs) -> fx.GraphModule: + gm = pm.fwd_only(*args, **kwargs) + view_to_reshape(gm) + return gm + + pm.register_replacement( + pattern, + replacement, + self.get_inputs(), + fwd_and_view_to_reshape, + pm_pass, + ) + + class RopeKVCacheFusionPass(VllmPatternMatcherPass): """ This pass fuses the rotary embedding and KV cache update operations @@ -247,18 +430,36 @@ def __init__(self, config: VllmConfig) -> None: cc = config.compilation_config self.max_token_num = cc.pass_config.rope_kvcache_fusion_max_token_num + quant_key = None + # Only CUDA supports RoPE + Query Quant + KV Cache Pattern + if current_platform.is_cuda(): + quant_key = QuantKey( + dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=True + ) + attn_layers = get_layers_from_vllm_config(config, Attention) # When _USE_LAYERNAME is enabled, layer_name is a wildcard so all # layers produce the same pattern — register once then break. for _, layer in attn_layers.items(): - if layer.impl.fused_rope_kvcache_supported(): - for is_neox in [True, False]: + if not layer.impl.fused_rope_kvcache_supported(quant_key): + continue + for is_neox in [True, False]: + if current_platform.is_cuda(): + for use_flashinfer_rope in [True, False]: + assert quant_key is not None + RopeQuantReshapeKVCachePattern( + layer=layer, + quant_key=quant_key, + is_neox=is_neox, + use_flashinfer_rope=use_flashinfer_rope, + ).register(self.patterns) + elif current_platform.is_rocm(): RopeReshapeKVCachePattern( layer=layer, is_neox=is_neox, ).register(self.patterns) - if _USE_LAYERNAME: - break + if _USE_LAYERNAME: + break self.dump_patterns(config, self.patterns) @@ -274,4 +475,6 @@ def is_applicable_for_range(self, compile_range: Range) -> bool: return compile_range.end <= self.max_token_num def uuid(self) -> str: - return VllmInductorPass.hash_source(self, RopeReshapeKVCachePattern) + return VllmInductorPass.hash_source( + self, RopeReshapeKVCachePattern, RopeQuantReshapeKVCachePattern + ) diff --git a/vllm/compilation/passes/utility/fix_functionalization.py b/vllm/compilation/passes/utility/fix_functionalization.py index 15eb23e6f949..380c3c1fd5ae 100644 --- a/vllm/compilation/passes/utility/fix_functionalization.py +++ b/vllm/compilation/passes/utility/fix_functionalization.py @@ -44,6 +44,8 @@ def __call__(self, graph: torch.fx.Graph) -> None: rope_targets.append( torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default ) + if hasattr(torch.ops.vllm, "flashinfer_rotary_embedding"): + rope_targets.append(torch.ops.vllm.flashinfer_rotary_embedding.default) for node in graph.nodes: if not is_func(node, auto_functionalized): @@ -179,6 +181,7 @@ def __call__(self, graph: torch.fx.Graph) -> None: mutated_args = { 1: "query", 2: "key", + 3: "query_quant_out", } self.defunctionalize(graph, node, mutated_args=mutated_args) # only used for test_functionalization::TestFunctionWithMutatedArgsAndReturn diff --git a/vllm/compilation/passes/utility/scatter_split_replace.py b/vllm/compilation/passes/utility/scatter_split_replace.py index a17a7b336d2d..88974fe3a7a2 100644 --- a/vllm/compilation/passes/utility/scatter_split_replace.py +++ b/vllm/compilation/passes/utility/scatter_split_replace.py @@ -63,6 +63,8 @@ def __call__(self, graph: fx.Graph) -> None: target_ops = [torch.ops._C.rotary_embedding.default] if hasattr(torch.ops.vllm, "rocm_aiter_triton_rotary_embedding"): target_ops.append(torch.ops.vllm.rocm_aiter_triton_rotary_embedding.default) + if hasattr(torch.ops.vllm, "flashinfer_rotary_embedding"): + target_ops.append(torch.ops.vllm.flashinfer_rotary_embedding.default) for node in graph.nodes: if not is_func(node, auto_functionalized): diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 11c933fc72f5..12126bc6deb2 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -138,15 +138,15 @@ class PassConfig: """Enable fused allreduce+RMSNorm for MiniMax QK norm.""" enable_qk_norm_rope_fusion: bool = False """Enable fused Q/K RMSNorm + RoPE pass.""" + fuse_rope_kvcache: bool = None # type: ignore[assignment] + """Fuse the QK rope (+ Q quant) + KV cache ops.""" # ROCm/AITER specific fusions fuse_act_padding: bool = None # type: ignore[assignment] """Fuse the custom RMSNorm + padding ops.""" - fuse_rope_kvcache: bool = None # type: ignore[assignment] - """Fuse the QK rope + KV cache ops.""" rope_kvcache_fusion_max_token_num: int = 256 - """The threshold for ROCm AITER RoPE+KVCache fusion e.g. for small batch decode. + """The threshold for RoPE(+Q quant)+KVCache fusion e.g. for small batch decode. Larger batch sizes e.g. during prefill will use the unfused kernels. """ @@ -270,10 +270,10 @@ def __post_init__(self) -> None: "The fusion will be disabled." ) self.fuse_act_padding = False - if self.fuse_rope_kvcache and not current_platform.is_rocm(): + if self.fuse_rope_kvcache and not current_platform.is_cuda_alike(): logger.warning_once( - "KV cache fusion currently only enabled on ROCm. " - "The fusion will be disabled." + "KV cache fusion enabled but the current platform is not " + "CUDA or ROCm. The fusion will be disabled." ) self.fuse_rope_kvcache = False @@ -1069,6 +1069,20 @@ def set_splitting_ops_for_v1( # list via reference. self.splitting_ops = list(self._attention_ops) + # Flashinfer fuse_rope_kvcache op needs to be a splitting op in + # piecewise cudagraph since it needs to access attn_metadata. + from vllm.utils.flashinfer import has_flashinfer + + if ( + current_platform.is_cuda() + and has_flashinfer() + and self.use_inductor_graph_partition + and self.pass_config.fuse_rope_kvcache + ): + self.splitting_ops.append( + "vllm::fused_rope_and_unified_kv_cache_update" + ) + # unified_kv_cache_update has a string param that prevents Inductor # from reusing piecewise graphs. Remove it from the compiled graph. # This has the side-effect of excluding cache from cudagraphs but diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 7bf75a67b2ce..5d9207644073 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -142,9 +142,12 @@ def enable_rope_kvcache_fusion(cfg: "VllmConfig") -> bool: use_inductor_graph_partition is enabled. """ from vllm._aiter_ops import rocm_aiter_ops + from vllm.platforms import current_platform + from vllm.utils.flashinfer import has_flashinfer return ( - rocm_aiter_ops.is_enabled() + current_platform.is_cuda_alike() + and (rocm_aiter_ops.is_enabled() or has_flashinfer()) and cfg.compilation_config.is_custom_op_enabled("rotary_embedding") and ( cfg.compilation_config.use_inductor_graph_partition diff --git a/vllm/envs.py b/vllm/envs.py index ee9d006aa987..be29c671e7cc 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -258,6 +258,7 @@ VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS: bool = False VLLM_NIXL_EP_MAX_NUM_RANKS: int = 32 VLLM_XPU_ENABLE_XPU_GRAPH: bool = False + VLLM_USE_FLASHINFER_ROPE: bool = False VLLM_LORA_ENABLE_DUAL_STREAM: bool = False @@ -1714,6 +1715,10 @@ def _get_or_set_default() -> str: "VLLM_USE_SIMPLE_KV_OFFLOAD": lambda: bool( int(os.getenv("VLLM_USE_SIMPLE_KV_OFFLOAD", "0")) ), + # If set to 1, use the FlashInfer's rotary embedding kernel + "VLLM_USE_FLASHINFER_ROPE": lambda: bool( + int(os.getenv("VLLM_USE_FLASHINFER_ROPE", "0")) + ), # Whether to enable dual cuda streams for LoRA computation "VLLM_LORA_ENABLE_DUAL_STREAM": lambda: bool( int(os.getenv("VLLM_LORA_ENABLE_DUAL_STREAM", "0")) diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index 1374334b2cad..7923e9b2eec8 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -4,8 +4,11 @@ import torch +from vllm import envs from vllm._aiter_ops import rocm_aiter_ops from vllm.model_executor.custom_op import CustomOp +from vllm.platforms import current_platform +from vllm.utils.flashinfer import has_flashinfer from .common import ApplyRotaryEmb @@ -34,18 +37,20 @@ def __init__( self.base = base self.is_neox_style = is_neox_style self.dtype = dtype - # TODO(mgoin): disabled for now due to failures - # Flashinfer only supports head_size=64, 128, 256, 512. - # https://github.com/flashinfer-ai/flashinfer/blob/ebfd655efe830048dba5d582aaa61d61d1cf9a87/include/flashinfer/utils.cuh#L174-L202 - # self.use_flashinfer = (self.enabled() - # and dtype in (torch.float16, torch.bfloat16) - # and current_platform.is_cuda() - # and has_flashinfer() - # and self.head_size in [64, 128, 256, 512]) # Check if use_flashinfer is already set if not hasattr(self, "use_flashinfer"): - self.use_flashinfer = False + # TODO(mgoin): VLLM_USE_FLASHINFER_ROPE is disabled for now due to failures + # Flashinfer only supports head_size=64, 128, 256, 512. + # https://github.com/flashinfer-ai/flashinfer/blob/ebfd655efe830048dba5d582aaa61d61d1cf9a87/include/flashinfer/utils.cuh#L174-L202 + self.use_flashinfer = ( + self.enabled() + and dtype in (torch.float16, torch.bfloat16) + and current_platform.is_cuda() + and has_flashinfer() + and self.head_size in [64, 128, 256, 512] + and envs.VLLM_USE_FLASHINFER_ROPE + ) self.use_aiter = ( self.enabled() and rocm_aiter_ops.is_triton_rotary_embed_enabled() diff --git a/vllm/model_executor/models/gpt_oss.py b/vllm/model_executor/models/gpt_oss.py index a9ec82974227..cd2ca9806967 100644 --- a/vllm/model_executor/models/gpt_oss.py +++ b/vllm/model_executor/models/gpt_oss.py @@ -83,7 +83,6 @@ def __init__( self.rotary_emb = get_rope( self.head_dim, max_position=config.max_position_embeddings, - dtype=torch.float32, rope_parameters={ "rope_theta": config.rope_parameters["rope_theta"], "rope_type": "yarn", diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index 28d077fcb771..6d44467a3aa7 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -421,7 +421,7 @@ def seq_lens_cpu(self) -> torch.Tensor: @deprecated( """ Prefer using device seq_lens directly to avoid implicit H<>D sync which breaks full - async scheduling. If a CPU copy is needed, it can be derived from + async scheduling. If a CPU copy is needed, it can be derived from query_start_loc_cpu and seq_lens. Will be removed in a future release, please migrate as soon as possible. """ @@ -792,9 +792,9 @@ def fused_output_quant_supported(self, quant_key: "QuantKey"): """ return False - def fused_rope_kvcache_supported(self): + def fused_rope_kvcache_supported(self, query_quant_key: "QuantKey | None" = None): """ - Does this attention implementation support RoPE+KVCache fusion. + Does this attention implementation support RoPE(+Quant)+KVCache fusion. This is used by the RopeKVCacheFusionPass to only fuse the RoPE ops with the KV cache update for implementations that support it. """ @@ -811,6 +811,9 @@ def do_rope_and_kv_cache_update( is_neox: bool, kv_cache: torch.Tensor, layer_slot_mapping: torch.Tensor, + attn_metadata: T | None = None, + query_quant_scale: torch.Tensor | None = None, + query_quant_out: torch.Tensor | None = None, ): """ If `fused_rope_kvcache_supported` returns True, this method will be called diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 4adad61c2fab..611bf01dd3d8 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -15,7 +15,9 @@ MultiLevelCascadeAttentionWrapper, ) from flashinfer.decode import fast_decode_plan, trtllm_batch_decode_with_kv_cache +from flashinfer.page import get_batch_indices_positions from flashinfer.prefill import trtllm_batch_context_with_kv_cache +from flashinfer.rope import rope_quantize_fp8_append_paged_kv_cache from flashinfer.utils import FP4Tensor from typing_extensions import override @@ -523,6 +525,16 @@ class FlashInferMetadata: cascade_wrapper: MultiLevelCascadeAttentionWrapper | None + # --- For RoPE + FP8 Quantize + KV Cache Update Kernel --- + paged_kv_indices: torch.Tensor | None = None + """Physical page indices for paged KV cache.""" + paged_kv_indptr: torch.Tensor | None = None + """Cumulative page count per request.""" + batch_indices: torch.Tensor | None = None + """Request index for each token.""" + paged_positions: torch.Tensor | None = None + """Position within each request's KV sequence.""" + class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): reorder_batch_threshold: int = 1 @@ -666,6 +678,24 @@ def __init__( self.paged_kv_indices = self._make_buffer(max_num_pages) self.paged_kv_last_page_len = self._make_buffer(max_num_reqs) + self.enabled_rope_quant_cache_fusion = ( + self.compilation_config.pass_config.fuse_rope_kvcache + and self.cache_dtype.startswith("fp8") + and can_use_trtllm + ) + if self.enabled_rope_quant_cache_fusion: + max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens + self.batch_indices = torch.empty( + max_num_tokens, + device=device, + dtype=torch.int32, + ) + self.paged_positions = torch.empty( + max_num_tokens, + device=device, + dtype=torch.int32, + ) + def _make_buffer( self, *size: int | torch.SymInt, dtype: torch.dtype = torch.int32 ) -> CpuGpuBuffer: @@ -945,7 +975,12 @@ def build( # Guard access to seq_lens_cpu, which may not always be needed # and can be expensive to retrieve in async mode. - needs_seq_lens_cpu = self.use_dcp or use_cascade or not is_only_trtllm_decode + needs_seq_lens_cpu = ( + self.use_dcp + or use_cascade + or not is_only_trtllm_decode + or self.enabled_rope_quant_cache_fusion + ) seq_lens_cpu = common_attn_metadata.seq_lens_cpu if needs_seq_lens_cpu else None seq_lens_np = seq_lens_cpu.numpy() if seq_lens_cpu is not None else None num_blocks_np = ( @@ -983,7 +1018,11 @@ def build( num_blocks_np -= num_common_kv_blocks # Compute paged_kv_indices if necessary - needs_paged_kv_indices = use_cascade or not is_only_trtllm_decode + needs_paged_kv_indices = ( + use_cascade + or not is_only_trtllm_decode + or self.enabled_rope_quant_cache_fusion + ) if needs_paged_kv_indices: assert num_blocks_np is not None assert seq_lens_np is not None @@ -1184,6 +1223,23 @@ def build( disable_split_kv=self.disable_split_kv, ) attn_metadata.decode = FIDecode(wrapper=decode_wrapper) + + # Step 4: Pre-compute params for RoPE + FP8 quantize + KV cache update fusion + # kernel here to avoid per-layer computation in do_rope_and_kv_cache_update. + if self.enabled_rope_quant_cache_fusion: + assert paged_kv_indices is not None + attn_metadata.paged_kv_indices = paged_kv_indices + attn_metadata.paged_kv_indptr = self.paged_kv_indptr.gpu[: num_reqs + 1] + attn_metadata.batch_indices = self.batch_indices[:num_actual_tokens] + attn_metadata.paged_positions = self.paged_positions[:num_actual_tokens] + get_batch_indices_positions( + qo_indptr[: num_reqs + 1], + seq_lens[:num_reqs], + num_actual_tokens, + attn_metadata.batch_indices, + attn_metadata.paged_positions, + ) + return attn_metadata def use_cascade_attention(self, *args, **kwargs) -> bool: @@ -1659,24 +1715,105 @@ def do_kv_cache_update( kv_cache: torch.Tensor, slot_mapping: torch.Tensor, ) -> None: - if self.kv_sharing_target_layer_name is None: - # Reshape the input keys and values and store them in the cache. - # Skip this if sharing KV cache with an earlier attention layer. - # NOTE(woosuk): Here, key and value are padded while slot_mapping is - # not padded. However, we don't need to do key[:num_actual_tokens] - # and value[:num_actual_tokens] because the reshape_and_cache_flash - # op uses the slot_mapping's shape to determine the number of - # actual tokens. - torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - kv_cache[:, 0], - kv_cache[:, 1], - slot_mapping, - self.kv_cache_dtype, - layer._k_scale, - layer._v_scale, - ) + # Reshape the input keys and values and store them in the cache. + # Skip this if sharing KV cache with an earlier attention layer. + # NOTE(woosuk): Here, key and value are padded while slot_mapping is + # not padded. However, we don't need to do key[:num_actual_tokens] + # and value[:num_actual_tokens] because the reshape_and_cache_flash + # op uses the slot_mapping's shape to determine the number of + # actual tokens. + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + kv_cache[:, 0], + kv_cache[:, 1], + slot_mapping, + self.kv_cache_dtype, + layer._k_scale, + layer._v_scale, + ) + + def fused_rope_kvcache_supported(self, query_quant_key: QuantKey | None = None): + return ( + self.support_trtllm_attn + and self.kv_cache_dtype.startswith("fp8") + and query_quant_key == kFp8StaticTensorSym + ) + + def do_rope_and_kv_cache_update( + self, + layer: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + is_neox: bool, + kv_cache: torch.Tensor, + layer_slot_mapping: torch.Tensor, + attn_metadata: FlashInferMetadata | None = None, + query_quant_scale: torch.Tensor | None = None, + query_quant_out: torch.Tensor | None = None, + ): + if attn_metadata is None: + # Skip this in piecewise cudagraph capturing since the kernel requires + # access to the attn_metadata + return + + assert cos_sin_cache.dtype == torch.float32 + assert query_quant_scale is not None + assert query_quant_out is not None + + quant_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( + self.kv_cache_dtype + ) + kv_cache = kv_cache.view(quant_dtype) + + stride_order = FlashInferBackend.get_kv_cache_stride_order() + kv_cache_perm = kv_cache.permute(*stride_order) + k_cache = kv_cache_perm[:, 0] + v_cache = kv_cache_perm[:, 1] + + kv_layout = get_kv_cache_layout() + page_size = k_cache.shape[1] if kv_layout == "NHD" else k_cache.shape[2] + + rotary_dim = cos_sin_cache.shape[-1] + head_size = query.shape[-1] + + q_rope = query[..., :rotary_dim] + k_rope = key[..., :rotary_dim] + q_rope_out = query_quant_out[..., :rotary_dim] + if rotary_dim < head_size: + q_nope = query[..., rotary_dim:] + k_nope = key[..., rotary_dim:] + q_nope_out = query_quant_out[..., rotary_dim:] + else: + q_nope = None + k_nope = None + q_nope_out = None + + rope_quantize_fp8_append_paged_kv_cache( + q_rope=q_rope, + k_rope=k_rope, + q_nope=q_nope, + k_nope=k_nope, + v=value, + cos_sin_cache=cos_sin_cache, + pos_ids=positions, + paged_kv_cache=(k_cache, v_cache), + kv_indices=attn_metadata.paged_kv_indices, + kv_indptr=attn_metadata.paged_kv_indptr, + batch_indices=attn_metadata.batch_indices, + positions=attn_metadata.paged_positions, + is_neox=is_neox, + quantize_dtype=quant_dtype, + quant_scale_q=layer._q_scale_float, + quant_scale_kv=layer._k_scale_float, + page_size=page_size, + kv_layout=kv_layout, + q_rope_out=q_rope_out, + q_nope_out=q_nope_out, + ) def fast_plan_decode( diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index a4cdce3f5d3e..9314402801ed 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -12,6 +12,7 @@ from vllm.config.cache import CacheDType from vllm.logger import init_logger from vllm.model_executor.layers.attention import Attention +from vllm.model_executor.layers.quantization.utils.quant_utils import QuantKey from vllm.platforms import current_platform from vllm.platforms.interface import DeviceCapability from vllm.utils.math_utils import cdiv @@ -1411,7 +1412,7 @@ def do_kv_cache_update( layer._v_scale, ) - def fused_rope_kvcache_supported(self): + def fused_rope_kvcache_supported(self, query_quant_key: QuantKey | None = None): # Only support fusion when shuffle KV cache layout is not used; # shuffle layout uses a different cache update path. return ( @@ -1430,6 +1431,9 @@ def do_rope_and_kv_cache_update( is_neox: bool, kv_cache: torch.Tensor, layer_slot_mapping: torch.Tensor, + attn_metadata: AiterFlashAttentionMetadata | None = None, + query_quant_scale: torch.Tensor | None = None, + query_quant_out: torch.Tensor | None = None, ): key_cache, value_cache = kv_cache.unbind(0) flash_layout = True diff --git a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py index eb0fe046e343..e1b3ec58253a 100644 --- a/vllm/v1/attention/backends/rocm_aiter_unified_attn.py +++ b/vllm/v1/attention/backends/rocm_aiter_unified_attn.py @@ -274,7 +274,7 @@ def do_kv_cache_update( layer._v_scale, ) - def fused_rope_kvcache_supported(self): + def fused_rope_kvcache_supported(self, query_quant_key: QuantKey | None = None): return rocm_aiter_ops.is_enabled() def do_rope_and_kv_cache_update( @@ -288,6 +288,9 @@ def do_rope_and_kv_cache_update( is_neox: bool, kv_cache: torch.Tensor, layer_slot_mapping: torch.Tensor, + attn_metadata: FlashAttentionMetadata | None = None, + query_quant_scale: torch.Tensor | None = None, + query_quant_out: torch.Tensor | None = None, ): if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): # For encoder attention, diff --git a/vllm/v1/attention/backends/rocm_attn.py b/vllm/v1/attention/backends/rocm_attn.py index 3a906233272a..89baca1ac49f 100644 --- a/vllm/v1/attention/backends/rocm_attn.py +++ b/vllm/v1/attention/backends/rocm_attn.py @@ -487,7 +487,7 @@ def do_kv_cache_update( layer._v_scale, ) - def fused_rope_kvcache_supported(self): + def fused_rope_kvcache_supported(self, query_quant_key: QuantKey | None = None): return rocm_aiter_ops.is_enabled() def do_rope_and_kv_cache_update( @@ -501,6 +501,9 @@ def do_rope_and_kv_cache_update( is_neox: bool, kv_cache: torch.Tensor, layer_slot_mapping: torch.Tensor, + attn_metadata: FlashAttentionMetadata | None = None, + query_quant_scale: torch.Tensor | None = None, + query_quant_out: torch.Tensor | None = None, ): if self.attn_type in (AttentionType.ENCODER_ONLY, AttentionType.ENCODER): return diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 76cae14aedb1..5aae4b965921 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -726,7 +726,7 @@ def do_kv_cache_update( layer._v_scale, ) - def fused_rope_kvcache_supported(self): + def fused_rope_kvcache_supported(self, query_quant_key: QuantKey | None = None): if self._is_per_token_head_quant: return False return rocm_aiter_ops.is_enabled() @@ -742,6 +742,9 @@ def do_rope_and_kv_cache_update( is_neox: bool, kv_cache: torch.Tensor, layer_slot_mapping: torch.Tensor, + attn_metadata: TritonAttentionMetadata | None = None, + query_quant_scale: torch.Tensor | None = None, + query_quant_out: torch.Tensor | None = None, ): key_cache, value_cache = kv_cache.unbind(1) flash_layout = True