diff --git a/tests/compile/passes/test_rope_kvcache_fusion.py b/tests/compile/passes/test_rope_kvcache_fusion.py index eea21c9179bd..74b0261b910e 100644 --- a/tests/compile/passes/test_rope_kvcache_fusion.py +++ b/tests/compile/passes/test_rope_kvcache_fusion.py @@ -36,7 +36,6 @@ from vllm.v1.kv_cache_interface import AttentionSpec INDEX_SELECT_OP = torch.ops.aten.index.Tensor -VLLM_UNIFIED_KV_CACHE_UPDATE_OP = torch.ops.vllm.unified_kv_cache_update FP8_DTYPE = current_platform.fp8_dtype() @@ -169,8 +168,8 @@ def forward( q = q.view(-1, self.num_heads, self.head_size) k = k.view(-1, self.num_kv_heads, self.head_size) v = v.view(-1, self.num_kv_heads, self.head_size) - kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update( - k, v, self.layer_name + q, k, kv_cache_dummy_dep = torch.ops.vllm.apply_kv_cache_update( + q, k, v, self.layer_name ) return q, k, v, kv_cache_dummy_dep @@ -183,7 +182,7 @@ def ops_in_model_before(self) -> list[torch._ops.OpOverload]: ops.append(ROTARY_OP) else: ops.append(INDEX_SELECT_OP) - ops.append(torch.ops.vllm.unified_kv_cache_update.default) + ops.append(torch.ops.vllm.apply_kv_cache_update.default) return ops def ops_in_model_after(self) -> list[torch._ops.OpOverload]: diff --git a/tests/quantization/test_hadamard_kv_dispatch.py b/tests/quantization/test_hadamard_kv_dispatch.py index cedf67100729..9c8f7794d35d 100644 --- a/tests/quantization/test_hadamard_kv_dispatch.py +++ b/tests/quantization/test_hadamard_kv_dispatch.py @@ -3,9 +3,12 @@ """ Tests for CompressedTensorsKVCacheMethod K_CACHE/Q_ATTN dispatch. -Covers _has_kq_attn_transform: config reading, per-layer targeting via +Covers _resolve_kv_transform: config reading, per-layer targeting via is_match, and validation errors raised at model load time. +Also covers apply_kv_cache: transform dispatch on scheme.type and the +no-op path when _ct_kv_transform is None. + These tests do not call the hadacore_transform kernel and therefore run on all platforms. 
""" @@ -73,49 +76,53 @@ def _make_layer(layer_name: str = "model.layers.0.self_attn", head_size: int = 1 # --------------------------------------------------------------------------- -# Enabled cases +# _resolve_kv_transform: enabled cases # --------------------------------------------------------------------------- -def test_kq_transform_set_for_k_cache(): - """K_CACHE location targeting this layer → True.""" +def test_kv_transform_set_for_k_cache(): + """K_CACHE location targeting this layer → TransformScheme stored.""" + from compressed_tensors.transform import TransformScheme + quant_config = _make_quant_config(_make_transform_config("k_cache")) method = CompressedTensorsKVCacheMethod(quant_config) layer = _make_layer() method.create_weights(layer) - assert layer._kq_attn_transform is True + assert isinstance(layer._ct_kv_transform, TransformScheme) -def test_kq_transform_set_for_q_attn(): - """Q_ATTN location targeting this layer → True.""" +def test_kv_transform_set_for_q_attn(): + """Q_ATTN location targeting this layer → TransformScheme stored.""" + from compressed_tensors.transform import TransformScheme + quant_config = _make_quant_config(_make_transform_config("q_attn")) method = CompressedTensorsKVCacheMethod(quant_config) layer = _make_layer() method.create_weights(layer) - assert layer._kq_attn_transform is True + assert isinstance(layer._ct_kv_transform, TransformScheme) # --------------------------------------------------------------------------- -# Disabled cases +# _resolve_kv_transform: disabled cases # --------------------------------------------------------------------------- -def test_kq_transform_none_without_transform_config(): - """No transform_config → _kq_attn_transform is False.""" +def test_kv_transform_none_without_transform_config(): + """No transform_config → _ct_kv_transform is None.""" quant_config = _make_quant_config(transform_config=None) method = CompressedTensorsKVCacheMethod(quant_config) layer = _make_layer() 
method.create_weights(layer) - assert layer._kq_attn_transform is False + assert layer._ct_kv_transform is None -def test_kq_transform_none_for_non_attention_locations(): +def test_kv_transform_none_for_non_attention_locations(): """INPUT/OUTPUT locations should not trigger KV rotation.""" quant_config = _make_quant_config(_make_transform_config("input")) method = CompressedTensorsKVCacheMethod(quant_config) @@ -123,20 +130,22 @@ def test_kq_transform_none_for_non_attention_locations(): layer = _make_layer() method.create_weights(layer) - assert layer._kq_attn_transform is False + assert layer._ct_kv_transform is None # --------------------------------------------------------------------------- -# Per-layer targeting +# _resolve_kv_transform: per-layer targeting # --------------------------------------------------------------------------- -def test_kq_transform_per_layer_targeting(): +def test_kv_transform_per_layer_targeting(): """Scheme targeting only self_attn layers must not fire on other layers. Real R3 checkpoints target by class name (e.g. LlamaAttention) or regex. This test uses a regex that matches *self_attn suffixes only. 
""" + from compressed_tensors.transform import TransformScheme + quant_config = _make_quant_config( _make_transform_config("k_cache", targets=["re:.*self_attn"]) ) @@ -148,14 +157,141 @@ def test_kq_transform_per_layer_targeting(): method.create_weights(attn_layer) method.create_weights(other_layer) - assert attn_layer._kq_attn_transform is True, ( - "self_attn layer should have rotation enabled" + assert isinstance(attn_layer._ct_kv_transform, TransformScheme), ( + "self_attn layer should have a resolved TransformScheme" + ) + assert other_layer._ct_kv_transform is None, ( + "mlp layer should not have a resolved TransformScheme" + ) + + +# --------------------------------------------------------------------------- +# apply_kv_cache: transform dispatch +# --------------------------------------------------------------------------- + + +def test_apply_query_calls_hadamard_transform(): + """apply_query with a hadamard scheme must rotate query.""" + from unittest.mock import patch + + quant_config = _make_quant_config(_make_transform_config("k_cache")) + method = CompressedTensorsKVCacheMethod(quant_config) + + layer = _make_layer() + method.create_weights(layer) + layer.calculate_kv_scales = False + + query = torch.randn(4, 8, 128) + rotated_query = torch.randn(4, 8, 128) + + with patch( + "vllm.model_executor.layers.quantization.compressed_tensors" + ".compressed_tensors.ops.hadacore_transform", + return_value=rotated_query, + ) as mock_rotate: + out_query = method.apply_query(layer, query) + + mock_rotate.assert_called_once_with(query) + assert out_query is rotated_query + + +def test_apply_kv_cache_calls_hadamard_transform(): + """apply_kv_cache with a hadamard scheme must rotate key and write cache.""" + from unittest.mock import MagicMock, patch + + quant_config = _make_quant_config(_make_transform_config("k_cache")) + method = CompressedTensorsKVCacheMethod(quant_config) + + layer = _make_layer() + method.create_weights(layer) + layer.calculate_kv_scales = False 
+ layer.impl = MagicMock() + + key = torch.randn(4, 8, 128) + value = torch.randn(4, 8, 128) + kv_cache = torch.zeros(2, 4, 8, 128) + slot_mapping = torch.arange(4) + + rotated_key = torch.randn(4, 8, 128) + + with patch( + "vllm.model_executor.layers.quantization.compressed_tensors" + ".compressed_tensors.ops.hadacore_transform", + return_value=rotated_key, + ) as mock_rotate: + method.apply_kv_cache(layer, key, value, kv_cache, slot_mapping) + + mock_rotate.assert_called_once_with(key) + layer.impl.do_kv_cache_update.assert_called_once_with( + layer, rotated_key, value, kv_cache, slot_mapping ) - assert other_layer._kq_attn_transform is False, ( - "mlp layer should not have rotation enabled" + + +def test_apply_query_no_transform_when_scheme_is_none(): + """apply_query with no resolved scheme must not rotate query.""" + from unittest.mock import patch + + quant_config = _make_quant_config(transform_config=None) + method = CompressedTensorsKVCacheMethod(quant_config) + + layer = _make_layer() + method.create_weights(layer) + layer.calculate_kv_scales = False + + query = torch.randn(4, 8, 128) + + with patch( + "vllm.model_executor.layers.quantization.compressed_tensors" + ".compressed_tensors.ops.hadacore_transform", + ) as mock_rotate: + out_query = method.apply_query(layer, query) + + mock_rotate.assert_not_called() + assert out_query is query + + +def test_apply_kv_cache_no_transform_when_scheme_is_none(): + """apply_kv_cache with no resolved scheme must not rotate key.""" + from unittest.mock import MagicMock, patch + + quant_config = _make_quant_config(transform_config=None) + method = CompressedTensorsKVCacheMethod(quant_config) + + layer = _make_layer() + method.create_weights(layer) + layer.calculate_kv_scales = False + layer.impl = MagicMock() + + key = torch.randn(4, 8, 128) + value = torch.randn(4, 8, 128) + kv_cache = torch.zeros(2, 4, 8, 128) + slot_mapping = torch.arange(4) + + with patch( + 
"vllm.model_executor.layers.quantization.compressed_tensors" + ".compressed_tensors.ops.hadacore_transform", + ) as mock_rotate: + method.apply_kv_cache(layer, key, value, kv_cache, slot_mapping) + + mock_rotate.assert_not_called() + layer.impl.do_kv_cache_update.assert_called_once_with( + layer, key, value, kv_cache, slot_mapping ) +def test_calculate_kv_scales_raises_at_load_time(): + """calculate_kv_scales=True with a matched transform scheme must raise at + create_weights time, not at forward time.""" + quant_config = _make_quant_config(_make_transform_config("k_cache")) + method = CompressedTensorsKVCacheMethod(quant_config) + + layer = _make_layer() + layer.calculate_kv_scales = True + + with pytest.raises(ValueError, match="calculate_kv_scales"): + method.create_weights(layer) + + # --------------------------------------------------------------------------- # Error cases (all raised at model load, not at runtime) # --------------------------------------------------------------------------- diff --git a/vllm/compilation/passes/fusion/rope_kvcache_fusion.py b/vllm/compilation/passes/fusion/rope_kvcache_fusion.py index 830a9640780c..77c46c284a2a 100644 --- a/vllm/compilation/passes/fusion/rope_kvcache_fusion.py +++ b/vllm/compilation/passes/fusion/rope_kvcache_fusion.py @@ -15,6 +15,7 @@ Attention, get_attention_context, ) +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.utils.torch_utils import direct_register_custom_op from ..inductor_pass import enable_fake_mode @@ -144,8 +145,10 @@ def pattern( q = q.view(-1, self.num_heads, self.head_size) k = k.view(-1, self.num_kv_heads, self.head_size) v = v.view(-1, self.num_kv_heads, self.head_size_v) - dummy = torch.ops.vllm.unified_kv_cache_update(k, v, self.layer_name) - return dummy, q, k, v + q_out, k_out, dummy = torch.ops.vllm.apply_kv_cache_update( + q, k, v, self.layer_name + ) + return dummy, q_out, k_out, v def replacement( qkv: torch.Tensor, @@ -207,6 +210,17 @@ def 
__init__(self, config: VllmConfig) -> None: attn_layers = get_layers_from_vllm_config(config, Attention) for _, layer in attn_layers.items(): if layer.impl.fused_rope_kvcache_supported(): + # If the layer's quant method overrides apply_kv_cache it has + # work to do between RoPE and the cache write (e.g. Hadamard + # rotation). The fused triton kernel cannot accommodate an + # inter-step transform, so leave those layers on the general + # apply_kv_cache_update path which dispatches through the method. + qm = getattr(layer, "quant_method", None) + if qm is not None and ( + type(qm).apply_kv_cache is not BaseKVCacheMethod.apply_kv_cache + or type(qm).apply_query is not BaseKVCacheMethod.apply_query + ): + continue for is_neox in [True, False]: RopeReshapeKVCachePattern( layer=layer, diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 716c208a9048..04e34e63fe4a 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -282,7 +282,7 @@ def log_enabled_passes(self) -> None: """ enabled_fusions = [ f.name[len("fuse_") :] - for f in fields(self) + for f in fields(self) # type: ignore[arg-type] if getattr(self, f.name) and f.name.startswith("fuse_") ] @@ -1084,6 +1084,7 @@ def set_splitting_ops_for_v1( ) self.pass_config.fuse_rope_kvcache = False self.splitting_ops.append("vllm::unified_kv_cache_update") + self.splitting_ops.append("vllm::apply_kv_cache_update") self.splitting_ops.append("vllm::unified_mla_kv_cache_update") elif len(self.splitting_ops) == 0: @@ -1173,6 +1174,7 @@ def splitting_ops_contain_kv_cache_update(self) -> bool: kv_cache_update_ops = [ "vllm::unified_kv_cache_update", + "vllm::apply_kv_cache_update", "vllm::unified_mla_kv_cache_update", ] return self.splitting_ops is not None and all( diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index 3d5e6fcf82ba..aaa4f2cbc52b 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ 
b/vllm/model_executor/layers/attention/attention.py @@ -6,7 +6,6 @@ import torch import torch.nn as nn -import vllm._custom_ops as ops import vllm.envs as envs from vllm.config import CacheConfig, get_current_vllm_config from vllm.config.vllm import VllmConfig @@ -378,26 +377,9 @@ def __init__( # this variable will not be accessed if use_direct_call is True self.kv_cache = torch.tensor([]) - # Set to True by CompressedTensorsKVCacheMethod.create_weights when the - # checkpoint's transform_config targets this layer with K_CACHE/Q_ATTN - # locations. Must be initialised before _init_kv_cache_quant. - self._kq_attn_transform = False - # Initialize KV cache quantization attributes _init_kv_cache_quant(self, quant_config, prefix) - # Guard against combining Hadamard rotation with online KV scale. - # maybe_calc_kv_scales runs before the rotation, so scales would be - # computed on unrotated K. - if self._kq_attn_transform and self.calculate_kv_scales: - raise ValueError( - f"Layer '{prefix}': cannot combine K_CACHE/Q_ATTN Hadamard " - "rotation with calculate_kv_scales=True. The KV scale " - "computation (maybe_calc_kv_scales) runs before the rotation " - "and would produce scales for unrotated K. Either disable " - "calculate_kv_scales or remove the transform config." - ) - # for attn backends supporting query quantization self.query_quant = None if ( @@ -470,26 +452,11 @@ def forward( if value is not None: value = value.view(-1, self.num_kv_heads, self.head_size_v) - # K_CACHE/Q_ATTN Hadamard rotation (R3/SpinQuant-style). - # Applied post-RoPE, post-reshape, before both the KV cache write - # and the attention computation, so prefill and decode are consistent. - # K_CACHE and Q_ATTN share the same rotation matrix; V and output - # un-rotation are absorbed offline into W_v and W_o weights. 
- if self._kq_attn_transform and key is not None: - query = ops.hadacore_transform(query) - key = ops.hadacore_transform(key) - kv_cache_dummy_dep = None if self.use_direct_call: - # Skip this if sharing KV cache with an earlier attention layer. - if ( - not self.attn_backend.forward_includes_kv_cache_update - and self.kv_sharing_target_layer_name is None - and key is not None - and value is not None - ): - kv_cache_dummy_dep = unified_kv_cache_update( - key, value, self.layer_name + if key is not None and value is not None: + query, key, kv_cache_dummy_dep = apply_kv_cache_update( + query, key, value, self.layer_name ) unified_attention_with_output( query, @@ -500,15 +467,11 @@ def forward( kv_cache_dummy_dep=kv_cache_dummy_dep, ) else: - # Skip this if sharing KV cache with an earlier attention layer. - if ( - not self.attn_backend.forward_includes_kv_cache_update - and self.kv_sharing_target_layer_name is None - and key is not None - and value is not None - ): - kv_cache_dummy_dep = torch.ops.vllm.unified_kv_cache_update( - key, value, self.layer_name + if key is not None and value is not None: + query, key, kv_cache_dummy_dep = ( + torch.ops.vllm.apply_kv_cache_update( + query, key, value, self.layer_name + ) ) torch.ops.vllm.unified_attention_with_output( query, @@ -734,6 +697,81 @@ def unified_kv_cache_update_fake( ) +def apply_kv_cache_update( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + layer_name: str, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Apply pre-cache transforms and write key/value to the paged KV cache. + + This is the MHA migration target for ``unified_kv_cache_update``. When + the attention layer has a ``quant_method``, it dispatches to + ``BaseKVCacheMethod.apply_kv_cache``, which is responsible for both any + pre-cache transforms (e.g. Hadamard rotation) and the cache write. + Without a ``quant_method`` it falls back to + ``attn_layer.impl.do_kv_cache_update``, preserving existing behaviour. 
+ + The write and any key transform are skipped (query is still transformed) when: + - ``forward_includes_kv_cache_update`` is True — the backend writes during + its own ``forward()`` call. + - ``kv_sharing_target_layer_name`` is set — this layer shares the cache + written by an earlier layer. + - ``slot_mapping`` is None — nothing to write (e.g. profiling run). + + Returns ``(query, key, dummy_dep)`` where ``dummy_dep`` is a zero-element + tensor threaded into ``unified_attention_with_output`` to preserve + ordering under torch.compile. + + Note: MLA uses a different tensor layout and a separate forward path; + this function is not called from ``MLAAttention.forward()``. + """ + _, attn_layer, kv_cache, slot_mapping = get_attention_context(layer_name) + + should_write = ( + not attn_layer.attn_backend.forward_includes_kv_cache_update + and attn_layer.kv_sharing_target_layer_name is None + and slot_mapping is not None + ) + + qm = getattr(attn_layer, "quant_method", None) + + # Q transform is unconditional: Q must be prepared for attention on every + # forward pass regardless of whether a cache write is happening. 
+ if qm is not None: + query = qm.apply_query(attn_layer, query) + + if should_write: + if qm is not None: + qm.apply_kv_cache(attn_layer, key, value, kv_cache, slot_mapping) + else: + assert hasattr(attn_layer.impl, "do_kv_cache_update"), ( + f"{attn_layer.impl.__class__.__name__} does not support kv cache update" + ) + attn_layer.impl.do_kv_cache_update( + attn_layer, key, value, kv_cache, slot_mapping + ) + + return query, key, torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype) + + +def apply_kv_cache_update_fake( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + layer_name: str, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return query, key, torch.empty(0, device=query.device, dtype=query.dtype) + + +direct_register_custom_op( + op_name="apply_kv_cache_update", + op_func=apply_kv_cache_update, + fake_impl=apply_kv_cache_update_fake, + mutates_args=[], +) + + @maybe_transfer_kv_layer def unified_attention_with_output( query: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index aff7430327e0..fee65b09e115 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -16,9 +16,14 @@ QuantizationStrategy, QuantizationType, ) -from compressed_tensors.transform import TransformConfig, TransformLocation +from compressed_tensors.transform import ( + TransformConfig, + TransformLocation, + TransformScheme, +) from compressed_tensors.utils import is_match +import vllm._custom_ops as ops from vllm.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -1094,19 +1099,22 @@ def _tp_aware_loader( _tp_aware_loader, kind="v", param_type="zero_point" ) - # Wire up KQ cache Hadamard rotation. 
- # K_CACHE and Q_ATTN share the same rotation matrix. - # The inverse rotation for V and output is done in W_v/W_o by llm-compressor. - layer._kq_attn_transform = self._has_kq_attn_transform(layer) + # Resolve at model-load time which KV cache transform (if any) applies + # to this layer. Stored on the layer so apply_kv_cache can dispatch + # at forward time without re-walking transform_config every call. + layer._ct_kv_transform = self._resolve_kv_transform(layer) - def _has_kq_attn_transform(self, layer: torch.nn.Module) -> bool: - """Return True if this layer should have K_CACHE/Q_ATTN Hadamard rotation. + def _resolve_kv_transform(self, layer: torch.nn.Module) -> TransformScheme | None: + """Return the TransformScheme for K_CACHE/Q_ATTN transforms on this + layer, or None if no transform applies. - Walks the transform_config looking for K_CACHE or Q_ATTN locations - whose targets match this specific layer. + Walks transform_config looking for K_CACHE or Q_ATTN locations whose + targets match this layer. Validates the scheme and raises at model + load time if the configuration is unsupported, so errors surface + before the first forward pass. """ if self.quant_config.transform_config is None: - return False + return None layer_name = getattr(layer, "layer_name", "") head_dim = getattr(layer, "head_size", None) @@ -1121,59 +1129,112 @@ def _has_kq_attn_transform(self, layer: torch.nn.Module) -> bool: TransformLocation.Q_ATTN, ): continue - if current_platform.is_rocm(): - raise NotImplementedError( - "K_CACHE/Q_ATTN Hadamard rotation requires the " - "hadacore_transform kernel, which is not supported " - "on ROCm." - ) if not is_match(layer_name, layer, args.targets, args.ignore): continue - if scheme.type != "hadamard": + if current_platform.is_rocm(): + raise NotImplementedError( + "K_CACHE/Q_ATTN transforms require the hadacore_transform " + "kernel, which is not supported on ROCm." 
+ ) + if scheme.type == "random-hadamard" or scheme.randomize: raise NotImplementedError( - f"KV cache rotation type '{scheme.type}' is not " - "supported. Only 'hadamard' (deterministic Sylvester " - "construction) is implemented. 'random-hadamard' and " - "'random-matrix' require loading a stored rotation " - "matrix, which is not yet implemented." + f"KV cache transform type '{scheme.type}' with " + f"randomize={scheme.randomize} requires loading a " + "per-layer rotation matrix from the checkpoint, which " + "is not yet implemented for the KV cache path. See " + "HadamardTransform in compressed_tensors for the " + "naming-gap and TP-check blockers." ) - if scheme.randomize: + if scheme.type != "hadamard": raise NotImplementedError( - "KV cache Hadamard rotation with randomize=True is " - "not supported. Randomized transforms require loading " - "a per-layer rotation matrix from the checkpoint, " - "which is not yet implemented." + f"KV cache transform type '{scheme.type}' is not " + "supported. Only 'hadamard' (deterministic) is " + "currently implemented." ) if head_dim is None: raise ValueError( - f"Layer '{layer_name}': K_CACHE/Q_ATTN Hadamard " - "rotation requires head_size attribute." + f"Layer '{layer_name}': K_CACHE/Q_ATTN transform " + "requires head_size attribute." ) if scheme.head_dim is not None and scheme.head_dim != head_dim: raise ValueError( f"Layer '{layer_name}': transform_config head_dim " f"({scheme.head_dim}) does not match layer head_size " - f"({head_dim}). K_CACHE/Q_ATTN rotation operates at " + f"({head_dim}). K_CACHE/Q_ATTN transforms operate at " "head granularity." ) if head_dim <= 0 or (head_dim & (head_dim - 1)) != 0: raise ValueError( - f"KV cache Hadamard rotation requires head_dim to be " - f"a power of two, got {head_dim} for layer " - f"'{layer_name}'." + f"K_CACHE/Q_ATTN transform requires head_dim to be a " + f"power of two, got {head_dim} for layer '{layer_name}'." 
) if head_dim > 2**15: raise ValueError( - f"KV cache Hadamard rotation requires head_dim <= " - f"2^15 (hadacore kernel constraint), got {head_dim} " - f"for layer '{layer_name}'." + f"K_CACHE/Q_ATTN transform requires head_dim <= 2^15 " + f"(hadacore kernel constraint), got {head_dim} for " + f"layer '{layer_name}'." ) - return True + if getattr(layer, "calculate_kv_scales", False): + raise ValueError( + f"Layer '{layer_name}': cannot combine a K_CACHE/Q_ATTN " + "transform with calculate_kv_scales=True. KV scale " + "computation runs before the transform and would produce " + "scales for un-transformed K." + ) - return False + return scheme + + return None + + def apply_query( + self, + layer: torch.nn.Module, + query: torch.Tensor, + ) -> torch.Tensor: + """Apply the Q_ATTN Hadamard rotation before the attention kernel. + + Called unconditionally on every forward pass so that Q is correctly + rotated for both prefill and decode, including decode steps where Q + attends against previously cached (rotated) K. + """ + scheme = getattr(layer, "_ct_kv_transform", None) + if scheme is not None and scheme.type == "hadamard": + query = ops.hadacore_transform(query) + return query + + def apply_kv_cache( + self, + layer: torch.nn.Module, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + ) -> None: + """Apply the K_CACHE Hadamard rotation then write to the paged cache. + + Dispatches on the TransformScheme stored by _resolve_kv_transform at + model load time. The cache write is delegated to the base class. + + Currently supported scheme types: + "hadamard" — deterministic Sylvester FWHT via ops.hadacore_transform. + K_CACHE and Q_ATTN share one rotation matrix; Q rotation is handled + in apply_query. The inverse for V and output is absorbed offline + into W_v and W_o by llm-compressor. + + Not yet supported: + "random-hadamard" — requires loading the per-layer rotation matrix + from the checkpoint. 
Blocked by the naming gap between the + HuggingFace self_attn prefix (where R3 weights live) and the + self_attn.attn prefix where create_weights runs, and by the + incorrect TP > 1 guard in HadamardTransform.__init__. + """ + scheme = getattr(layer, "_ct_kv_transform", None) + if scheme is not None and scheme.type == "hadamard": + key = ops.hadacore_transform(key) + super().apply_kv_cache(layer, key, value, kv_cache, slot_mapping) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: """ diff --git a/vllm/model_executor/layers/quantization/kv_cache.py b/vllm/model_executor/layers/quantization/kv_cache.py index 726ac2232af9..3a79247136cc 100644 --- a/vllm/model_executor/layers/quantization/kv_cache.py +++ b/vllm/model_executor/layers/quantization/kv_cache.py @@ -46,6 +46,66 @@ def create_weights(self, layer: torch.nn.Module): def apply(self, layer: torch.nn.Module) -> torch.Tensor: raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.") + def apply_query( + self, + layer: torch.nn.Module, + query: torch.Tensor, + ) -> torch.Tensor: + """Transform query before the attention kernel. + + Called unconditionally from ``apply_kv_cache_update`` on every forward + pass, after RoPE and reshape. Subclasses that rotate or otherwise + transform Q (e.g. Hadamard Q_ATTN rotation) should override this. + + The default implementation is the identity. + + Note: this interface covers standard decoder-only MHA. For + architectures where Q is computed separately from the cache write + (e.g. encoder-decoder cross-attention), this hook is not called and + the transform must be handled differently. + + Args: + layer: the ``Attention`` layer instance. + query: ``[num_tokens, num_heads, head_size]``. + + Returns: + Transformed query tensor. 
+ """ + return query + + def apply_kv_cache( + self, + layer: torch.nn.Module, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + ) -> None: + """Apply transforms and write key/value to the paged KV cache. + + Called from ``apply_kv_cache_update`` only when a cache write should + occur (``should_write`` is True). Subclasses that need pre-cache + transforms on K (e.g. Hadamard rotation) should override this method. + Call ``super().apply_kv_cache(...)`` to delegate the write after + transforming. + + Query transforms belong in ``apply_query``, which is called + unconditionally on every forward pass. + + Note: this interface covers standard MHA only. MLA uses a separate + code path. + + Args: + layer: the ``Attention`` layer instance. + key: ``[num_tokens, num_kv_heads, head_size]`` — written to the + paged cache as-is by default; overrides may transform it first. + value: ``[num_tokens, num_kv_heads, head_size_v]`` — written to + the paged cache unchanged by the default implementation. + kv_cache: paged KV cache tensor. + slot_mapping: token-to-slot mapping for the current batch. + """ + layer.impl.do_kv_cache_update(layer, key, value, kv_cache, slot_mapping) + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # skip if there are no weights to process (for example, weight reloading) if not hasattr(layer, "q_scale"):