Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions tests/compile/passes/test_rope_kvcache_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
from vllm.v1.kv_cache_interface import AttentionSpec

INDEX_SELECT_OP = torch.ops.aten.index.Tensor
VLLM_UNIFIED_KV_CACHE_UPDATE_OP = torch.ops.vllm.unified_kv_cache_update
FP8_DTYPE = current_platform.fp8_dtype()


Expand Down Expand Up @@ -169,8 +168,8 @@ def forward(
q = q.view(-1, self.num_heads, self.head_size)
k = k.view(-1, self.num_kv_heads, self.head_size)
v = v.view(-1, self.num_kv_heads, self.head_size)
kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update(
k, v, self.layer_name
q, k, kv_cache_dummy_dep = torch.ops.vllm.apply_kv_cache_update(
q, k, v, self.layer_name
)
return q, k, v, kv_cache_dummy_dep

Expand All @@ -183,7 +182,7 @@ def ops_in_model_before(self) -> list[torch._ops.OpOverload]:
ops.append(ROTARY_OP)
else:
ops.append(INDEX_SELECT_OP)
ops.append(torch.ops.vllm.unified_kv_cache_update.default)
ops.append(torch.ops.vllm.apply_kv_cache_update.default)
return ops

def ops_in_model_after(self) -> list[torch._ops.OpOverload]:
Expand Down
176 changes: 156 additions & 20 deletions tests/quantization/test_hadamard_kv_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
"""
Tests for CompressedTensorsKVCacheMethod K_CACHE/Q_ATTN dispatch.

Covers _has_kq_attn_transform: config reading, per-layer targeting via
Covers _resolve_kv_transform: config reading, per-layer targeting via
is_match, and validation errors raised at model load time.

Also covers apply_query and apply_kv_cache: transform dispatch on
scheme.type and the no-op path when _ct_kv_transform is None.

These tests do not call the hadacore_transform kernel and therefore
run on all platforms.
"""
Expand Down Expand Up @@ -73,70 +76,76 @@ def _make_layer(layer_name: str = "model.layers.0.self_attn", head_size: int = 1


# ---------------------------------------------------------------------------
# Enabled cases
# _resolve_kv_transform: enabled cases
# ---------------------------------------------------------------------------


def test_kq_transform_set_for_k_cache():
"""K_CACHE location targeting this layer → True."""
def test_kv_transform_set_for_k_cache():
"""K_CACHE location targeting this layer → TransformScheme stored."""
from compressed_tensors.transform import TransformScheme

quant_config = _make_quant_config(_make_transform_config("k_cache"))
method = CompressedTensorsKVCacheMethod(quant_config)

layer = _make_layer()
method.create_weights(layer)

assert layer._kq_attn_transform is True
assert isinstance(layer._ct_kv_transform, TransformScheme)


def test_kq_transform_set_for_q_attn():
"""Q_ATTN location targeting this layer → True."""
def test_kv_transform_set_for_q_attn():
"""Q_ATTN location targeting this layer → TransformScheme stored."""
from compressed_tensors.transform import TransformScheme

quant_config = _make_quant_config(_make_transform_config("q_attn"))
method = CompressedTensorsKVCacheMethod(quant_config)

layer = _make_layer()
method.create_weights(layer)

assert layer._kq_attn_transform is True
assert isinstance(layer._ct_kv_transform, TransformScheme)


# ---------------------------------------------------------------------------
# Disabled cases
# _resolve_kv_transform: disabled cases
# ---------------------------------------------------------------------------


def test_kq_transform_none_without_transform_config():
"""No transform_config → _kq_attn_transform is False."""
def test_kv_transform_none_without_transform_config():
"""No transform_config → _ct_kv_transform is None."""
quant_config = _make_quant_config(transform_config=None)
method = CompressedTensorsKVCacheMethod(quant_config)

layer = _make_layer()
method.create_weights(layer)

assert layer._kq_attn_transform is False
assert layer._ct_kv_transform is None


def test_kq_transform_none_for_non_attention_locations():
def test_kv_transform_none_for_non_attention_locations():
"""INPUT/OUTPUT locations should not trigger KV rotation."""
quant_config = _make_quant_config(_make_transform_config("input"))
method = CompressedTensorsKVCacheMethod(quant_config)

layer = _make_layer()
method.create_weights(layer)

assert layer._kq_attn_transform is False
assert layer._ct_kv_transform is None


# ---------------------------------------------------------------------------
# Per-layer targeting
# _resolve_kv_transform: per-layer targeting
# ---------------------------------------------------------------------------


def test_kq_transform_per_layer_targeting():
def test_kv_transform_per_layer_targeting():
"""Scheme targeting only self_attn layers must not fire on other layers.

Real R3 checkpoints target by class name (e.g. LlamaAttention) or regex.
This test uses a regex that matches *self_attn suffixes only.
"""
from compressed_tensors.transform import TransformScheme

quant_config = _make_quant_config(
_make_transform_config("k_cache", targets=["re:.*self_attn"])
)
Expand All @@ -148,14 +157,141 @@ def test_kq_transform_per_layer_targeting():
method.create_weights(attn_layer)
method.create_weights(other_layer)

assert attn_layer._kq_attn_transform is True, (
"self_attn layer should have rotation enabled"
assert isinstance(attn_layer._ct_kv_transform, TransformScheme), (
"self_attn layer should have a resolved TransformScheme"
)
assert other_layer._ct_kv_transform is None, (
"mlp layer should not have a resolved TransformScheme"
)


# ---------------------------------------------------------------------------
# apply_query / apply_kv_cache: transform dispatch
# ---------------------------------------------------------------------------


def test_apply_query_calls_hadamard_transform():
    """A resolved hadamard scheme makes apply_query return the rotated query.

    The hadacore kernel is patched out, so this only checks that apply_query
    hands the query to the transform and returns whatever it produces.
    """
    from unittest import mock

    hadacore_target = (
        "vllm.model_executor.layers.quantization.compressed_tensors"
        ".compressed_tensors.ops.hadacore_transform"
    )

    kv_method = CompressedTensorsKVCacheMethod(
        _make_quant_config(_make_transform_config("k_cache"))
    )
    attn_layer = _make_layer()
    kv_method.create_weights(attn_layer)
    attn_layer.calculate_kv_scales = False

    q_in = torch.randn(4, 8, 128)
    q_rotated = torch.randn(4, 8, 128)

    with mock.patch(hadacore_target, return_value=q_rotated) as hadacore_mock:
        q_out = kv_method.apply_query(attn_layer, q_in)

    # The transform must see the original query exactly once, and its result
    # must be passed through untouched (same object, not a copy).
    hadacore_mock.assert_called_once_with(q_in)
    assert q_out is q_rotated


def test_apply_kv_cache_calls_hadamard_transform():
    """A resolved hadamard scheme rotates the key before the cache write.

    Both the hadacore kernel and the attention impl are mocked, so the test
    only verifies the call wiring: the key goes through the transform and
    the rotated key (with the untouched value) reaches do_kv_cache_update.
    """
    from unittest import mock

    hadacore_target = (
        "vllm.model_executor.layers.quantization.compressed_tensors"
        ".compressed_tensors.ops.hadacore_transform"
    )

    kv_method = CompressedTensorsKVCacheMethod(
        _make_quant_config(_make_transform_config("k_cache"))
    )
    attn_layer = _make_layer()
    kv_method.create_weights(attn_layer)
    attn_layer.calculate_kv_scales = False
    attn_layer.impl = mock.MagicMock()

    k_in = torch.randn(4, 8, 128)
    v_in = torch.randn(4, 8, 128)
    cache = torch.zeros(2, 4, 8, 128)
    slots = torch.arange(4)

    k_rotated = torch.randn(4, 8, 128)

    with mock.patch(hadacore_target, return_value=k_rotated) as hadacore_mock:
        kv_method.apply_kv_cache(attn_layer, k_in, v_in, cache, slots)

    # Only the key is rotated; the cache update receives the rotated key.
    hadacore_mock.assert_called_once_with(k_in)
    attn_layer.impl.do_kv_cache_update.assert_called_once_with(
        attn_layer, k_rotated, v_in, cache, slots
    )
assert other_layer._kq_attn_transform is False, (
"mlp layer should not have rotation enabled"


def test_apply_query_no_transform_when_scheme_is_none():
    """Without a transform_config, apply_query returns the query unrotated."""
    from unittest import mock

    hadacore_target = (
        "vllm.model_executor.layers.quantization.compressed_tensors"
        ".compressed_tensors.ops.hadacore_transform"
    )

    kv_method = CompressedTensorsKVCacheMethod(
        _make_quant_config(transform_config=None)
    )
    attn_layer = _make_layer()
    kv_method.create_weights(attn_layer)
    attn_layer.calculate_kv_scales = False

    q_in = torch.randn(4, 8, 128)

    with mock.patch(hadacore_target) as hadacore_mock:
        q_out = kv_method.apply_query(attn_layer, q_in)

    # No scheme resolved: the kernel is never invoked and the very same
    # tensor object comes back.
    hadacore_mock.assert_not_called()
    assert q_out is q_in


def test_apply_kv_cache_no_transform_when_scheme_is_none():
    """Without a transform_config, apply_kv_cache writes the key unrotated.

    The attention impl is mocked; the test checks that the transform is
    skipped and the original key is forwarded to do_kv_cache_update.
    """
    from unittest import mock

    hadacore_target = (
        "vllm.model_executor.layers.quantization.compressed_tensors"
        ".compressed_tensors.ops.hadacore_transform"
    )

    kv_method = CompressedTensorsKVCacheMethod(
        _make_quant_config(transform_config=None)
    )
    attn_layer = _make_layer()
    kv_method.create_weights(attn_layer)
    attn_layer.calculate_kv_scales = False
    attn_layer.impl = mock.MagicMock()

    k_in = torch.randn(4, 8, 128)
    v_in = torch.randn(4, 8, 128)
    cache = torch.zeros(2, 4, 8, 128)
    slots = torch.arange(4)

    with mock.patch(hadacore_target) as hadacore_mock:
        kv_method.apply_kv_cache(attn_layer, k_in, v_in, cache, slots)

    hadacore_mock.assert_not_called()
    attn_layer.impl.do_kv_cache_update.assert_called_once_with(
        attn_layer, k_in, v_in, cache, slots
    )


def test_calculate_kv_scales_raises_at_load_time():
    """calculate_kv_scales=True with a matched scheme fails in create_weights.

    The ValueError has to surface while weights are created, not later at
    forward time.
    """
    kv_method = CompressedTensorsKVCacheMethod(
        _make_quant_config(_make_transform_config("k_cache"))
    )
    attn_layer = _make_layer()
    attn_layer.calculate_kv_scales = True

    with pytest.raises(ValueError, match="calculate_kv_scales"):
        kv_method.create_weights(attn_layer)


# ---------------------------------------------------------------------------
# Error cases (all raised at model load, not at runtime)
# ---------------------------------------------------------------------------
Expand Down
18 changes: 16 additions & 2 deletions vllm/compilation/passes/fusion/rope_kvcache_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Attention,
get_attention_context,
)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.utils.torch_utils import direct_register_custom_op

from ..inductor_pass import enable_fake_mode
Expand Down Expand Up @@ -144,8 +145,10 @@ def pattern(
q = q.view(-1, self.num_heads, self.head_size)
k = k.view(-1, self.num_kv_heads, self.head_size)
v = v.view(-1, self.num_kv_heads, self.head_size_v)
dummy = torch.ops.vllm.unified_kv_cache_update(k, v, self.layer_name)
return dummy, q, k, v
q_out, k_out, dummy = torch.ops.vllm.apply_kv_cache_update(
q, k, v, self.layer_name
)
return dummy, q_out, k_out, v

def replacement(
qkv: torch.Tensor,
Expand Down Expand Up @@ -207,6 +210,17 @@ def __init__(self, config: VllmConfig) -> None:
attn_layers = get_layers_from_vllm_config(config, Attention)
for _, layer in attn_layers.items():
if layer.impl.fused_rope_kvcache_supported():
# If the layer's quant method overrides apply_kv_cache or apply_query it
# has work to do between RoPE and the cache write (e.g. Hadamard
# rotation). The fused triton kernel cannot accommodate an
# inter-step transform, so leave those layers on the general
# apply_kv_cache_update path which dispatches through the method.
qm = getattr(layer, "quant_method", None)
if qm is not None and (
type(qm).apply_kv_cache is not BaseKVCacheMethod.apply_kv_cache
or type(qm).apply_query is not BaseKVCacheMethod.apply_query
):
continue
for is_neox in [True, False]:
RopeReshapeKVCachePattern(
layer=layer,
Expand Down
4 changes: 3 additions & 1 deletion vllm/config/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def log_enabled_passes(self) -> None:
"""
enabled_fusions = [
f.name[len("fuse_") :]
for f in fields(self)
for f in fields(self) # type: ignore[arg-type]
if getattr(self, f.name) and f.name.startswith("fuse_")
]

Expand Down Expand Up @@ -1084,6 +1084,7 @@ def set_splitting_ops_for_v1(
)
self.pass_config.fuse_rope_kvcache = False
self.splitting_ops.append("vllm::unified_kv_cache_update")
self.splitting_ops.append("vllm::apply_kv_cache_update")
self.splitting_ops.append("vllm::unified_mla_kv_cache_update")

elif len(self.splitting_ops) == 0:
Expand Down Expand Up @@ -1173,6 +1174,7 @@ def splitting_ops_contain_kv_cache_update(self) -> bool:

kv_cache_update_ops = [
"vllm::unified_kv_cache_update",
"vllm::apply_kv_cache_update",
"vllm::unified_mla_kv_cache_update",
]
return self.splitting_ops is not None and all(
Expand Down
Loading
Loading