Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tests/compile/fusions_e2e/test_tp2_ar_rms.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def test_tp2_ar_rms_fp8_fusions(
fuse_attn_quant=True,
enable_qk_norm_rope_fusion=True,
fuse_allreduce_rms=True,
fuse_rope_kvcache=False, # FIXME: disable to avoid compile range split
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of disabling the rope-cache fusion in tests, can we adjust the compile range logic?

),
)

Expand Down Expand Up @@ -150,6 +151,7 @@ def test_tp2_ar_rms_fp4_fusions(
fuse_act_quant=True,
fuse_attn_quant=True,
fuse_allreduce_rms=True,
fuse_rope_kvcache=False, # FIXME: disable to avoid compile range split
),
)

Expand Down Expand Up @@ -204,6 +206,7 @@ def test_tp2_ar_rms_fusions(
pass_config=PassConfig(
enable_qk_norm_rope_fusion=True,
fuse_allreduce_rms=True,
fuse_rope_kvcache=False, # FIXME: disable to avoid compile range split
),
)

Expand Down
4 changes: 4 additions & 0 deletions tests/compile/fusions_e2e/test_tp2_async_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def test_tp2_async_tp_fp8_fusions(
enable_sp=True,
fuse_gemm_comms=True,
fuse_allreduce_rms=False,
fuse_rope_kvcache=False, # FIXME: disable to avoid compile range split
# Override threshold for testing (models have small hidden_size)
sp_min_token_num=512,
),
Expand Down Expand Up @@ -132,6 +133,7 @@ def test_tp2_async_tp_fusions(
enable_sp=True,
fuse_gemm_comms=True,
fuse_allreduce_rms=False,
fuse_rope_kvcache=False, # FIXME: disable to avoid compile range split
# Override threshold for testing (models have small hidden_size)
sp_min_token_num=512,
),
Expand Down Expand Up @@ -197,6 +199,7 @@ def test_tp2_sp_ar_rms_fp8_fusions(
enable_sp=True,
fuse_gemm_comms=True,
fuse_allreduce_rms=True,
fuse_rope_kvcache=False, # FIXME: disable to avoid compile range split
# Override threshold for testing (models have small hidden_size)
sp_min_token_num=512,
),
Expand Down Expand Up @@ -258,6 +261,7 @@ def test_tp2_sp_ar_rms_fusions(
enable_sp=True,
fuse_gemm_comms=True,
fuse_allreduce_rms=True,
fuse_rope_kvcache=False, # FIXME: disable to avoid compile range split
# Override threshold for testing (models have small hidden_size)
sp_min_token_num=512,
),
Expand Down
229 changes: 216 additions & 13 deletions tests/compile/passes/test_rope_kvcache_fusion.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy

import pytest
import torch
Expand All @@ -8,7 +9,7 @@
from tests.compile.backend import TestBackend
from tests.v1.attention.utils import BatchSpec, create_common_attn_metadata
from vllm._aiter_ops import is_aiter_found_and_supported, rocm_aiter_ops
from vllm.compilation.passes.fusion.matcher_utils import ROTARY_OP
from vllm.compilation.passes.fusion.matcher_utils import QUANT_OPS, ROTARY_OP
from vllm.compilation.passes.fusion.rope_kvcache_fusion import RopeKVCacheFusionPass
from vllm.compilation.passes.utility.noop_elimination import NoOpEliminationPass
from vllm.compilation.passes.utility.post_cleanup import PostCleanupPass
Expand All @@ -17,42 +18,49 @@
)
from vllm.compilation.passes.utility.split_coalescing import SplitCoalescingPass
from vllm.config import (
AttentionConfig,
CacheConfig,
CompilationConfig,
CompilationMode,
ModelConfig,
PassConfig,
SchedulerConfig,
VllmConfig,
set_current_vllm_config,
)
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.attention import Attention
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer
from vllm.utils.torch_utils import _encode_layer_name
from vllm.v1.attention.backend import (
AttentionBackend,
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.kv_cache_interface import AttentionSpec, get_kv_quant_mode

INDEX_SELECT_OP = torch.ops.aten.index.Tensor
VLLM_UNIFIED_KV_CACHE_UPDATE_OP = torch.ops.vllm.unified_kv_cache_update
FP8_DTYPE = current_platform.fp8_dtype()


class QKRoPEKVCacheTestModel(torch.nn.Module):
class QKRoPEKVCacheTestModelBase(torch.nn.Module):
def __init__(
self,
vllm_config: VllmConfig,
attn_backend: AttentionBackendEnum,
num_heads: int,
num_kv_heads: int,
head_size: int,
is_neox: bool,
dtype: torch.dtype,
device: torch.device,
prefix: str = "model.layers.0.self_attn.attn",
attn_backend: AttentionBackendEnum = None,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this ever None?

):
super().__init__()
self.num_heads = num_heads
Expand Down Expand Up @@ -87,7 +95,7 @@ def __init__(
cache_config=vllm_config.cache_config,
quant_config=vllm_config.quant_config,
prefix=prefix,
attn_backend=attn_backend.get_class(),
attn_backend=attn_backend.get_class() if attn_backend is not None else None,
)
self.attn_backend: type[AttentionBackend] = self.attn.get_attn_backend()
assert not self.attn_backend.forward_includes_kv_cache_update, (
Expand All @@ -96,18 +104,14 @@ def __init__(
self.attn._k_scale = self.attn._k_scale.to(device)
self.attn._v_scale = self.attn._v_scale.to(device)

kv_cache_dtype_str = vllm_config.cache_config.cache_dtype
self.kv_cache_dtype = (
FP8_DTYPE if kv_cache_dtype_str.startswith("fp8") else self.dtype
)

# Initialize attn MetadataBuilder
self.builder = self.attn.attn_backend.get_builder_cls()(
kv_cache_spec=AttentionSpec(
block_size=self.block_size,
num_kv_heads=self.num_kv_heads,
head_size=head_size,
dtype=self.kv_cache_dtype,
dtype=self.attn.kv_cache_torch_dtype,
kv_quant_mode=get_kv_quant_mode(self.attn.kv_cache_dtype),
),
layer_names=[self.attn.layer_name],
vllm_config=vllm_config,
Expand Down Expand Up @@ -143,7 +147,7 @@ def build_attn_metadata(self, batch_size: int) -> CommonAttentionMetadata:
# Create dummy KV cache
raw_tensor = torch.zeros(
2 * num_blocks * self.block_size * self.num_kv_heads * self.head_size,
dtype=self.kv_cache_dtype,
dtype=self.attn.kv_cache_torch_dtype,
device=self.device,
)
raw_tensor = raw_tensor.view(kv_cache_shape)
Expand All @@ -158,6 +162,19 @@ def build_attn_metadata(self, batch_size: int) -> CommonAttentionMetadata:

return attn_metadata

def forward(
self, qkv: torch.Tensor, positions: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
raise NotImplementedError

def ops_in_model_before(self) -> list[torch._ops.OpOverload]:
raise NotImplementedError

def ops_in_model_after(self) -> list[torch._ops.OpOverload]:
raise NotImplementedError


class QKRoPEKVCacheTestModel(QKRoPEKVCacheTestModelBase):
def forward(
self, qkv: torch.Tensor, positions: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
Expand Down Expand Up @@ -191,6 +208,39 @@ def ops_in_model_after(self) -> list[torch._ops.OpOverload]:
return [torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default]


class QKRoPEQuantKVCacheTestModel(QKRoPEKVCacheTestModelBase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

assert self.attn.query_quant is not None

def forward(self, qkv: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
qkv = qkv.clone()
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
return attn_output

def ops_in_model_before(self) -> list[torch._ops.OpOverload]:
ops = []
if self.enable_rope_custom_op:
if self.rotary_emb.use_flashinfer:
ops.append(torch.ops.vllm.flashinfer_rotary_embedding.default)
else:
ops.append(ROTARY_OP)
else:
ops.append(INDEX_SELECT_OP)
if self.attn.query_quant.enabled():
ops.append(QUANT_OPS[kFp8StaticTensorSym])
else:
ops.append(torch.ops.aten.reciprocal)
ops.append(torch.ops.vllm.unified_kv_cache_update.default)
return ops

def ops_in_model_after(self) -> list[torch._ops.OpOverload]:
return [torch.ops.vllm.fused_rope_and_unified_kv_cache_update.default]


@pytest.mark.parametrize(
"attn_backend",
[
Expand Down Expand Up @@ -259,13 +309,13 @@ def test_rope_kvcache_fusion(

model = QKRoPEKVCacheTestModel(
vllm_config=vllm_config,
attn_backend=attn_backend,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
head_size=head_size,
is_neox=is_neox,
dtype=dtype,
device=torch.get_default_device(),
attn_backend=attn_backend,
)

fusion_pass = RopeKVCacheFusionPass(vllm_config)
Expand Down Expand Up @@ -333,3 +383,156 @@ def test_rope_kvcache_fusion(
atol=ATOL,
rtol=RTOL,
)


@pytest.mark.parametrize("attn_backend", [AttentionBackendEnum.FLASHINFER])
@pytest.mark.parametrize("model_name", ["openai/gpt-oss-20b"])
@pytest.mark.parametrize("enable_rope_custom_op", [True])
@pytest.mark.parametrize("enable_quant_custom_op", [True, False])
@pytest.mark.parametrize("enable_flashinfer_rope", [True, False])
@pytest.mark.parametrize("batch_size", [7, 64, 533])
@pytest.mark.parametrize("num_heads", [64])
@pytest.mark.parametrize("num_kv_heads", [8])
@pytest.mark.parametrize("head_size", [64])
@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("is_neox", [True, False])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("kv_cache_dtype", ["fp8"])
@pytest.mark.skipif(
not (
current_platform.is_cuda()
and current_platform.is_device_capability((10, 0))
and has_flashinfer()
),
reason="Only test on CUDA Blackwell platform with FlashInfer installed",
)
def test_rope_quant_kvcache_fusion(
attn_backend: AttentionBackendEnum,
model_name: str,
enable_rope_custom_op: bool,
enable_quant_custom_op: bool,
enable_flashinfer_rope: bool,
batch_size: int,
num_heads: int,
num_kv_heads: int,
head_size: int,
block_size: int,
is_neox: bool,
dtype: torch.dtype,
kv_cache_dtype: str,
monkeypatch: pytest.MonkeyPatch,
):
monkeypatch.setenv("VLLM_DISABLE_COMPILE_CACHE", "1")
if enable_flashinfer_rope:
monkeypatch.setenv("VLLM_USE_FLASHINFER_ROPE", "1")

torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(42)

custom_ops: list[str] = []
if enable_rope_custom_op:
custom_ops.append("+rotary_embedding")
if enable_quant_custom_op:
custom_ops.append("+quant_fp8")

model_config = ModelConfig(
model=model_name,
max_model_len=2048,
dtype=dtype,
)

vllm_config = VllmConfig(
model_config=model_config,
scheduler_config=SchedulerConfig(
max_num_seqs=1024,
max_model_len=model_config.max_model_len,
is_encoder_decoder=model_config.is_encoder_decoder,
),
cache_config=CacheConfig(
block_size=block_size,
cache_dtype=kv_cache_dtype,
),
attention_config=AttentionConfig(
backend=attn_backend,
),
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
custom_ops=custom_ops,
pass_config=PassConfig(
eliminate_noops=False,
fuse_rope_kvcache=False,
),
),
)

hidden_size = head_size * (num_heads + num_kv_heads * 2)
qkv = torch.randn(batch_size, hidden_size, dtype=dtype)
pos = torch.arange(batch_size, dtype=torch.long)

# Run model directly without fusion
vllm_config_unfused = copy.deepcopy(vllm_config)
with (
set_current_vllm_config(vllm_config_unfused),
set_forward_context(attn_metadata=None, vllm_config=vllm_config_unfused),
):
model_unfused = QKRoPEQuantKVCacheTestModel(
vllm_config=vllm_config_unfused,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
head_size=head_size,
is_neox=is_neox,
dtype=dtype,
device=torch.get_default_device(),
)
forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_unfused.build_attn_metadata(batch_size)
forward_ctx.slot_mapping = {
model_unfused.layer_name: forward_ctx.attn_metadata.slot_mapping
}
compiled_unfused = torch.compile(model_unfused, fullgraph=True)
result_unfused = compiled_unfused(qkv.clone(), pos.clone())

# Run model with fusion enabled
vllm_config.compilation_config.pass_config = PassConfig(
eliminate_noops=True,
fuse_rope_kvcache=True,
)
with (
set_current_vllm_config(vllm_config),
set_forward_context(attn_metadata=None, vllm_config=vllm_config),
):
model_fused = QKRoPEQuantKVCacheTestModel(
vllm_config=vllm_config,
num_heads=num_heads,
num_kv_heads=num_kv_heads,
head_size=head_size,
is_neox=is_neox,
dtype=dtype,
device=torch.get_default_device(),
)
forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_fused.build_attn_metadata(batch_size)
forward_ctx.slot_mapping = {
model_fused.layer_name: forward_ctx.attn_metadata.slot_mapping
}

# Create test backend with fusion passes enabled
fusion_pass = RopeKVCacheFusionPass(vllm_config)
passes = [
NoOpEliminationPass(vllm_config),
SplitCoalescingPass(vllm_config),
ScatterSplitReplacementPass(vllm_config),
fusion_pass,
PostCleanupPass(vllm_config),
]
backend = TestBackend(*passes)
compiled_fused = torch.compile(model_fused, backend=backend, fullgraph=True)
result_fused = compiled_fused(qkv.clone(), pos.clone())

assert fusion_pass.matched_count == 1

backend.check_before_ops(model_fused.ops_in_model_before())
backend.check_after_ops(model_fused.ops_in_model_after())

torch.testing.assert_close(result_unfused, result_fused, atol=1e-2, rtol=1e-2)
Loading
Loading