From 12fae71cb52c7c291d473a28a7e349098b42c16d Mon Sep 17 00:00:00 2001 From: Kimish Patel Date: Thu, 1 May 2025 10:19:16 -0700 Subject: [PATCH] [Executorch][llm] Enable leveraging ring kv cache via module swap This allows us to make some of the attention modules use a sliding window kv cache. This will help enable models like gemma3. Differential Revision: [D73891426](https://our.internmc.facebook.com/intern/diff/D73891426/) [ghstack-poisoned] --- examples/models/llama/attention.py | 18 +- .../source_transformation/custom_kv_cache.py | 191 +++++++++++++++++- examples/models/llama/tests/TARGETS | 25 +++ .../llama/tests/test_replace_kv_cache.py | 158 +++++++++++++++ .../models/llama/tests/test_ring_attention.py | 161 ++++++++++++--- 5 files changed, 522 insertions(+), 31 deletions(-) create mode 100644 examples/models/llama/tests/test_replace_kv_cache.py diff --git a/examples/models/llama/attention.py b/examples/models/llama/attention.py index 862b43bc3f5..ef3690df0fe 100644 --- a/examples/models/llama/attention.py +++ b/examples/models/llama/attention.py @@ -150,6 +150,16 @@ def forward( return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) +def _create_causal_mask_for_ring_buffer( + cache_positions, window_size, start_pos, seq_len +): + pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1) + delta = pos_q - cache_positions + attn_mask = (cache_positions >= 0) & (delta >= 0) & (delta < window_size) + attn_mask = torch.where(attn_mask == True, 0, float("-inf")) # noqa E712 + return attn_mask + + class CacheUpdateStrategy(Enum): RING_BUFFER = "RingBuffer" INVALID = "Invalid" @@ -283,12 +293,10 @@ def __init__( self.is_ring_buffer = True def create_causal_mask_for_ring_buffer(self, start_pos, seq_len): - pos_q = start_pos + torch.arange(seq_len, dtype=torch.long).view(-1, 1) cache_positions = self.cache_positions_manager.cache_positions - delta = pos_q - cache_positions - attn_mask = (cache_positions >= 0) & (delta >= 0) & (delta < self.window_size) - attn_mask = torch.where(attn_mask == True, 0, float("-inf")) # noqa E712 - return attn_mask + return _create_causal_mask_for_ring_buffer( + cache_positions, self.window_size, start_pos, seq_len + ) def update( self, input_pos: torch.Tensor, k_val: torch.Tensor, v_val: torch.Tensor diff --git a/examples/models/llama/source_transformation/custom_kv_cache.py b/examples/models/llama/source_transformation/custom_kv_cache.py index 9361204f6bc..24038959dba 100644 --- a/examples/models/llama/source_transformation/custom_kv_cache.py +++ b/examples/models/llama/source_transformation/custom_kv_cache.py @@ -10,7 +10,12 @@ import torch import torch.nn as nn -from executorch.examples.models.llama.attention import KVCache +from executorch.examples.models.llama.attention import ( + _create_causal_mask_for_ring_buffer, + CachePositionsManager, + KVCache, + RingKVCache, +) from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401 @@ -75,6 +80,7 @@ def __init__( self.register_buffer( "v_cache_zero_points", torch.ones(scale_shape, dtype=torch.int8) ) + self.cache_type = cache_type def _quantize(self, value): ( @@ -181,6 +187,7 @@ def update(self, input_pos, k_val, v_val, indices=None): However the storage is [B, S, H, D] so we incur transpose in, transpose out This shall be removed by subsequent post-export graph pass """ + k_val = k_val.transpose(1, 2) v_val = v_val.transpose(1, 2) @@ -346,3 +353,185 @@ def _replace_kv_cache_with_custom_kv_cache(module): else: _replace_kv_cache_with_custom_kv_cache(child)
return module + + +class QuantizedRingKVCache(QuantizedKVCache): + def __init__( + self, + max_batch_size, + max_context_length, + n_heads, + head_dim, + cache_type: QuantizedCacheType = QuantizedCacheType.AffineSymmetric, + use_custom_update_cache_op: bool = False, + ): + # Look at attention.py for explanation on why max_context_length * 2 + super().__init__( + max_batch_size, + max_context_length * 2, + n_heads, + head_dim, + cache_type, + use_custom_update_cache_op, + ) + self.cache_positions_manager = CachePositionsManager(self.max_context_length) + self.is_ring_buffer = True + self.window_size = max_context_length + + def create_causal_mask_for_ring_buffer(self, start_pos, seq_len): + cache_positions = self.cache_positions_manager.cache_positions + return _create_causal_mask_for_ring_buffer( + cache_positions, self.window_size, start_pos, seq_len + ) + + def update(self, input_pos, k_val, v_val): + """ + k_val, v_val: [B, H, S, D] + return: [B, H, S, D] + However the storage is [B, S, H, D] so we incur transpose in, transpose out + This shall be removed by subsequent post-export graph pass + """ + # Need to transpose for two reasons + # 1. kv cache is stored as [B, S, H, D] + # 2. If seq_len = k_val.size(2), we won't be able to optimize + # away transpose at the output of k, v projection + seq_len = k_val.transpose(1, 2).size(1) + assert seq_len <= self.k_cache.size( + 1 + ), f"Update sequence length({seq_len}) for kv cache must not exceed the cache size({self.k_cache.size(1)})" + indices = self.cache_positions_manager.calculate_positions_and_update_indices( + input_pos, seq_len + ) + indices = indices.unsqueeze(0) + + return super().update(input_pos, k_val, v_val, indices) + + @classmethod + def from_quantized_kv_cache( + cls, + kv_cache, + sliding_window_size, + ): + assert isinstance( + kv_cache, QuantizedKVCache + ), "For QuantizedRingKVCache expect QuantizedKVCache as input kv_cache" + max_batch_size, _, n_heads, head_dim = kv_cache.k_cache.shape + return cls( + max_batch_size, + sliding_window_size, + n_heads, + head_dim, + kv_cache.cache_type, + kv_cache.use_custom_update_cache_op, + ) + + +class CustomRingKVCache(CustomKVCache): + def __init__( + self, + max_batch_size, + max_context_length, + n_heads, + head_dim, + dtype=torch.float32, + ): + # Look at attention.py for explanation on why max_context_length * 2 + super().__init__( + max_batch_size, max_context_length * 2, n_heads, head_dim, dtype + ) + self.cache_positions_manager = CachePositionsManager(self.max_context_length) + self.is_ring_buffer = True + self.window_size = max_context_length + + def create_causal_mask_for_ring_buffer(self, start_pos, seq_len): + cache_positions = self.cache_positions_manager.cache_positions + return _create_causal_mask_for_ring_buffer( + cache_positions, self.window_size, start_pos, seq_len + ) + + def update(self, input_pos, k_val, v_val): + """ + k_val, v_val: [B, H, S, D] + return: [B, H, S, D] + However the storage is [B, S, H, D] so we incur transpose in, transpose out + This shall be removed by subsequent post-export graph pass + """ + # Need to transpose for two reasons + # 1. kv cache is stored as [B, S, H, D] + # 2.
If seq_len = k_val.size(2), we won't be able to optimize + # away transpose at the output of k, v projection + seq_len = k_val.transpose(1, 2).size(1) + assert seq_len <= self.k_cache.size( + 1 + ), f"Update sequence length({seq_len}) for kv cache must not exceed the cache size({self.k_cache.size(1)})" + indices = self.cache_positions_manager.calculate_positions_and_update_indices( + input_pos, seq_len + ) + indices = indices.unsqueeze(0) + + return super().update(input_pos, k_val, v_val, indices) + + @classmethod + def from_custom_kv_cache( + cls, + kv_cache, + sliding_window_size, + ): + max_batch_size, n_heads, _, head_dim = kv_cache.k_cache.shape + if isinstance(kv_cache, CustomKVCache): + # If replacing custom kv cache, then the shape is [B, S, H, D] + max_batch_size, _, n_heads, head_dim = kv_cache.k_cache.shape + return cls( + max_batch_size, + sliding_window_size, + n_heads, + head_dim, + dtype=kv_cache.k_cache.dtype, + ) + + +def _replace_kv_cache_with_ring_kv_cache(attention, layer_size): + sliding_window_size = layer_size + assert ( + getattr(attention, "kv_cache", None) is not None + ), "Attention module must have kv_cache module" + kv_cache = attention.kv_cache + if isinstance(kv_cache, KVCache): + attention.kv_cache = RingKVCache( + kv_cache.max_batch_size, + sliding_window_size, + kv_cache.n_heads, + kv_cache.head_dim, + kv_cache.enable_dynamic_shape, + kv_cache.k_cache.dtype, + ) + elif isinstance(kv_cache, CustomKVCache): + attention.kv_cache = CustomRingKVCache.from_custom_kv_cache( + kv_cache, layer_size + ) + elif isinstance(kv_cache, QuantizedKVCache): + attention.kv_cache = QuantizedRingKVCache.from_quantized_kv_cache( + kv_cache, layer_size + ) + + +def replace_kv_cache_with_ring_kv_cache(module, layer_sizes): + # This is needed to ensure that custom ops are registered + from executorch.extension.llm.custom_ops import custom_ops # noqa: F401 + + logging.info( + "Replacing kv cache with ring kv cache. This modifies the model in place." + ) + assert len(layer_sizes) == len( + module.layers + ), f"Length of layer sizes {len(layer_sizes)} must match the number of layers in the module {len(module.layers)}." + for i, transformer_block in enumerate(module.layers): + sliding_window_size = layer_sizes[i] + if sliding_window_size == 0: + continue + assert ( + getattr(transformer_block, "attention", None) is not None + ), f"Transformer block must have attention module. 
Transformer block {transformer_block}" + attention = transformer_block.attention + _replace_kv_cache_with_ring_kv_cache(attention, sliding_window_size) + return module diff --git a/examples/models/llama/tests/TARGETS b/examples/models/llama/tests/TARGETS index 0d52cfa19d3..40ab6653c60 100644 --- a/examples/models/llama/tests/TARGETS +++ b/examples/models/llama/tests/TARGETS @@ -55,8 +55,33 @@ python_unittest( srcs = [ "test_ring_attention.py", ], + preload_deps = [ + "//executorch/extension/llm/custom_ops:custom_ops_aot_lib", + "//executorch/kernels/quantized:aot_lib", + ], deps = [ "//caffe2:torch", + "//executorch/examples/models/llama:export_library", + "//executorch/examples/models/llama:llama_transformer", + "//executorch/examples/models/llama:custom_kv_cache", + "//executorch/examples/models/llama:sdpa", + ], +) + +python_unittest( + name = "test_replace_kv_cache", + srcs = [ + "test_replace_kv_cache.py", + ], + preload_deps = [ + "//executorch/extension/llm/custom_ops:custom_ops_aot_lib", + "//executorch/kernels/quantized:aot_lib", + ], + deps = [ + "//caffe2:torch", + "//executorch/examples/models/llama:export_library", "//executorch/examples/models/llama:llama_transformer", + "//executorch/examples/models/llama:custom_kv_cache", + "//executorch/examples/models/llama:sdpa", ], ) diff --git a/examples/models/llama/tests/test_replace_kv_cache.py b/examples/models/llama/tests/test_replace_kv_cache.py new file mode 100644 index 00000000000..8d7171633b2 --- /dev/null +++ b/examples/models/llama/tests/test_replace_kv_cache.py @@ -0,0 +1,158 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import unittest +from typing import List + +import torch.nn as nn + +from executorch.examples.models.llama.attention import ( + Attention, + AttentionMHA, + KVCache, + RingKVCache, + Rope, +) +from executorch.examples.models.llama.model_args import ModelArgs +from executorch.examples.models.llama.source_transformation.custom_kv_cache import ( + CustomKVCache, + CustomRingKVCache, + QuantizedKVCache, + QuantizedRingKVCache, + replace_kv_cache_with_custom_kv_cache, + replace_kv_cache_with_quantized_kv_cache, + replace_kv_cache_with_ring_kv_cache, +) + + +class MockTransformerBlock(nn.Module): + def __init__(self, attention: Attention): + super().__init__() + self.attention = attention + + +class TestReplaceKVCache(unittest.TestCase): + def setUp(self): + # Common parameters for creating attention modules + self.batch_size = 2 + self.seq_len = 10 + self.dim = 32 + self.n_heads = 4 + self.n_kv_heads = 2 + self.head_dim = 8 + self.max_context_len = 16 + self.enable_dynamic_shape = True + + # Create model args + self.args = ModelArgs( + dim=self.dim, + n_heads=self.n_heads, + n_kv_heads=self.n_kv_heads, + head_dim=self.head_dim, + max_batch_size=self.batch_size, + max_context_len=self.max_context_len, + use_kv_cache=True, + enable_dynamic_shape=self.enable_dynamic_shape, + ) + + # Create a rope instance + self.rope = Rope(self.args) + + def _create_attention_with_kv_cache(self) -> Attention: + """Create an attention module with KVCache.""" + return AttentionMHA(self.args, layer_id=0, rope=self.rope) + + def _create_mock_model(self, attention_modules: List[Attention]) -> nn.Module: + """Create a mock model with transformer blocks containing the given attention modules.""" + model = nn.Module() + model.layers = nn.ModuleList( + [MockTransformerBlock(attention) for attention in attention_modules] + ) + return model + + def test_replace_kv_cache_with_ring_kv_cache(self): + """Test replacing KVCache with RingKVCache.""" + # Create a model with KVCache + attention = self._create_attention_with_kv_cache() + model = self._create_mock_model([attention]) + + # Verify that the model has KVCache + self.assertIsInstance(model.layers[0].attention.kv_cache, KVCache) + self.assertNotIsInstance(model.layers[0].attention.kv_cache, RingKVCache) + + # Replace KVCache with RingKVCache + layer_sizes = [8] # Sliding window size for each layer + replace_kv_cache_with_ring_kv_cache(model, layer_sizes) + + # Verify that KVCache has been replaced with RingKVCache + self.assertIsInstance(model.layers[0].attention.kv_cache, RingKVCache) + + # Verify that the sliding window size is set correctly + self.assertEqual(model.layers[0].attention.kv_cache.window_size, layer_sizes[0]) + + def test_replace_custom_kv_cache_with_custom_ring_kv_cache(self): + """Test replacing CustomKVCache with CustomRingKVCache.""" + # Create a model with KVCache + attention = self._create_attention_with_kv_cache() + model = self._create_mock_model([attention]) + + # Replace KVCache with CustomKVCache + replace_kv_cache_with_custom_kv_cache(model) + + # Verify that the model has CustomKVCache + self.assertIsInstance(model.layers[0].attention.kv_cache, CustomKVCache) + self.assertNotIsInstance(model.layers[0].attention.kv_cache, CustomRingKVCache) + + # Replace CustomKVCache with CustomRingKVCache + layer_sizes = [8] # Sliding window size for each layer + replace_kv_cache_with_ring_kv_cache(model, layer_sizes) + + # Verify that CustomKVCache has been replaced with CustomRingKVCache + self.assertIsInstance(model.layers[0].attention.kv_cache, 
CustomRingKVCache) + + def test_replace_quantized_kv_cache_with_quantized_ring_kv_cache(self): + """Test replacing QuantizedKVCache with QuantizedRingKVCache.""" + # Create a model with KVCache + attention = self._create_attention_with_kv_cache() + model = self._create_mock_model([attention]) + + # Replace KVCache with QuantizedKVCache + replace_kv_cache_with_quantized_kv_cache(model) + + # Verify that the model has QuantizedKVCache + self.assertIsInstance(model.layers[0].attention.kv_cache, QuantizedKVCache) + self.assertNotIsInstance( + model.layers[0].attention.kv_cache, QuantizedRingKVCache + ) + + # Replace QuantizedKVCache with QuantizedRingKVCache + layer_sizes = [8] # Sliding window size for each layer + replace_kv_cache_with_ring_kv_cache(model, layer_sizes) + + # Verify that QuantizedKVCache has been replaced with QuantizedRingKVCache + self.assertIsInstance(model.layers[0].attention.kv_cache, QuantizedRingKVCache) + + def test_multiple_layers_with_different_window_sizes(self): + """Test replacing KV caches in multiple layers with different window sizes.""" + # Create a model with multiple layers + attention1 = self._create_attention_with_kv_cache() + attention2 = self._create_attention_with_kv_cache() + attention3 = self._create_attention_with_kv_cache() + model = self._create_mock_model([attention1, attention2, attention3]) + + # Replace KVCache with RingKVCache with different window sizes + layer_sizes = [4, 8, 16] # Different sliding window sizes for each layer + replace_kv_cache_with_ring_kv_cache(model, layer_sizes) + + # Verify that each layer has the correct window size + self.assertIsInstance(model.layers[0].attention.kv_cache, RingKVCache) + self.assertEqual(model.layers[0].attention.kv_cache.window_size, layer_sizes[0]) + + self.assertIsInstance(model.layers[1].attention.kv_cache, RingKVCache) + self.assertEqual(model.layers[1].attention.kv_cache.window_size, layer_sizes[1]) + + self.assertIsInstance(model.layers[2].attention.kv_cache, RingKVCache) + self.assertEqual(model.layers[2].attention.kv_cache.window_size, layer_sizes[2]) diff --git a/examples/models/llama/tests/test_ring_attention.py b/examples/models/llama/tests/test_ring_attention.py index a3f1bfd95ba..27198ae2021 100644 --- a/examples/models/llama/tests/test_ring_attention.py +++ b/examples/models/llama/tests/test_ring_attention.py @@ -6,11 +6,26 @@ import copy import unittest +from enum import Enum import torch from executorch.examples.models.llama.attention import AttentionMHA, RingKVCache from executorch.examples.models.llama.model_args import ModelArgs from executorch.examples.models.llama.rope import Rope +from executorch.examples.models.llama.source_transformation.custom_kv_cache import ( + CustomKVCache, + CustomRingKVCache, + QuantizedKVCache, + QuantizedRingKVCache, + replace_kv_cache_with_custom_kv_cache, + replace_kv_cache_with_quantized_kv_cache, +) + + +class KVCacheType(Enum): + REGULAR = "regular" + QUANTIZED = "quantized" + CUSTOM = "custom" class TestRingAttention(unittest.TestCase): @@ -27,7 +42,9 @@ def setUp(self): self.dtype = torch.float32 self.device = "cpu" - def _create_baseline_attention(self, seq_len: int): + def _create_baseline_attention( + self, seq_len: int, kv_cache_type: KVCacheType = KVCacheType.REGULAR + ): """Create baseline attention with regular KV cache.""" # Create model args self.args = ModelArgs( @@ -49,24 +66,54 @@ def _create_baseline_attention(self, seq_len: int): seq_len, self.max_context_len, self.sliding_window ) - return attention - - def 
_create_ring_attention(self, attention): + # Replace the KV cache with the specified type + if kv_cache_type == KVCacheType.QUANTIZED: + # Create a copy to avoid modifying the original attention + attention_copy = copy.deepcopy(attention) + # Replace KVCache with QuantizedKVCache + replace_kv_cache_with_quantized_kv_cache(attention_copy) + return attention_copy + elif kv_cache_type == KVCacheType.CUSTOM: + # Create a copy to avoid modifying the original attention + attention_copy = copy.deepcopy(attention) + # Replace KVCache with CustomKVCache + replace_kv_cache_with_custom_kv_cache(attention_copy) + return attention_copy + else: + return attention + + def _create_ring_attention( + self, attention, kv_cache_type: KVCacheType = KVCacheType.REGULAR + ): """Create attention with ring buffer KV cache.""" assert self.sliding_window is not None # Create RoPE instance self.rope = Rope(self.args) baseline_attention = copy.deepcopy(attention) - # Replace the KV cache with a ring buffer KV cache - baseline_attention.kv_cache = RingKVCache( - self.args.max_batch_size, - self.sliding_window, - self.n_kv_heads, - self.head_dim, - self.args.enable_dynamic_shape, - self.dtype, - ) + # Replace the KV cache with a ring buffer KV cache based on the type + if isinstance(baseline_attention.kv_cache, QuantizedKVCache): + # Replace QuantizedKVCache with QuantizedRingKVCache + baseline_attention.kv_cache = QuantizedRingKVCache.from_quantized_kv_cache( + baseline_attention.kv_cache, + self.sliding_window, + ) + elif isinstance(baseline_attention.kv_cache, CustomKVCache): + # Replace CustomKVCache with CustomRingKVCache + baseline_attention.kv_cache = CustomRingKVCache.from_custom_kv_cache( + baseline_attention.kv_cache, + self.sliding_window, + ) + else: + # Replace regular KVCache with RingKVCache + baseline_attention.kv_cache = RingKVCache( + self.args.max_batch_size, + self.sliding_window, + self.n_kv_heads, + self.head_dim, + self.args.enable_dynamic_shape, + self.dtype, + ) return baseline_attention def _create_sliding_window_mask(self, seq_len, context_len, window_size): @@ -79,12 +126,20 @@ def _create_sliding_window_mask(self, seq_len, context_len, window_size): mask[i, start_idx : pos + 1] = 0 return mask - def test_single_token_processing(self): + def _run_test_with_kv_cache_type(self, test_func, kv_cache_type: KVCacheType): + """Run a test with the specified KV cache type.""" + original_test_name = test_func.__name__ + print(f"\nRunning {original_test_name} with {kv_cache_type.value} KV cache") + test_func(kv_cache_type) + + def test_single_token_processing( + self, kv_cache_type: KVCacheType = KVCacheType.REGULAR + ): """Test that ring buffer and baseline produce the same output for single token processing.""" seq_len = 10 self.sliding_window = 4 - baseline_attn = self._create_baseline_attention(seq_len) - ring_attn = self._create_ring_attention(baseline_attn) + baseline_attn = self._create_baseline_attention(seq_len, kv_cache_type) + ring_attn = self._create_ring_attention(baseline_attn, kv_cache_type) # Process tokens one by one for pos in range(seq_len): @@ -109,17 +164,31 @@ def test_single_token_processing(self): f"Outputs differ at position {pos}", ) - def test_sliding_window_attention(self): + def test_single_token_processing_quantized(self): + """Test single token processing with QuantizedKVCache.""" + self._run_test_with_kv_cache_type( + self.test_single_token_processing, KVCacheType.QUANTIZED + ) + + def test_single_token_processing_custom(self): + """Test single token processing with 
CustomKVCache.""" + self._run_test_with_kv_cache_type( + self.test_single_token_processing, KVCacheType.CUSTOM + ) + + def test_sliding_window_attention( + self, kv_cache_type: KVCacheType = KVCacheType.REGULAR + ): """Test that ring buffer with sliding window size produces the same output as baseline with sliding window mask.""" self.sliding_window = 4 self.max_context_len = 16 seq_len = 10 # Create baseline attention with full context length - baseline_attn = self._create_baseline_attention(seq_len) + baseline_attn = self._create_baseline_attention(seq_len, kv_cache_type) # Create ring attention with sliding window size - ring_attn = self._create_ring_attention(baseline_attn) + ring_attn = self._create_ring_attention(baseline_attn, kv_cache_type) # Process tokens one by one for pos in range(seq_len): @@ -143,16 +212,32 @@ def test_sliding_window_attention(self): f"Outputs differ at position {pos}", ) - def test_ring_buffer_wrapping(self): + def test_sliding_window_attention_quantized(self): + """Test sliding window attention with QuantizedKVCache.""" + self._run_test_with_kv_cache_type( + self.test_sliding_window_attention, KVCacheType.QUANTIZED + ) + + def test_sliding_window_attention_custom(self): + """Test sliding window attention with CustomKVCache.""" + self._run_test_with_kv_cache_type( + self.test_sliding_window_attention, KVCacheType.CUSTOM + ) + + def test_ring_buffer_wrapping( + self, kv_cache_type: KVCacheType = KVCacheType.REGULAR + ): """Test that ring buffer correctly wraps around and maintains correct attention patterns.""" self.sliding_window = 3 self.max_context_len = 15 # Create baseline attention with full context length - baseline_attn = self._create_baseline_attention(self.max_context_len) + baseline_attn = self._create_baseline_attention( + self.max_context_len, kv_cache_type + ) # Create ring attention with sliding window size - ring_attn = self._create_ring_attention(baseline_attn) + ring_attn = self._create_ring_attention(baseline_attn, kv_cache_type) # Process enough tokens to cause wrapping seq_len = 1 @@ -188,7 +273,21 @@ def test_ring_buffer_wrapping(self): f"Expected positions {expected_positions}, got {cache_positions}", ) - def test_large_context_with_sliding_window(self): + def test_ring_buffer_wrapping_quantized(self): + """Test ring buffer wrapping with QuantizedKVCache.""" + self._run_test_with_kv_cache_type( + self.test_ring_buffer_wrapping, KVCacheType.QUANTIZED + ) + + def test_ring_buffer_wrapping_custom(self): + """Test ring buffer wrapping with CustomKVCache.""" + self._run_test_with_kv_cache_type( + self.test_ring_buffer_wrapping, KVCacheType.CUSTOM + ) + + def test_large_context_with_sliding_window( + self, kv_cache_type: KVCacheType = KVCacheType.REGULAR + ): """Test with a large context length and compare baseline with sliding window to ring buffer.""" # Use a larger context length and sliding window for this test self.max_context_len = 64 @@ -197,10 +296,10 @@ def test_large_context_with_sliding_window(self): token_lens = [8, 1, 3, 2, 1, 1, 1, 1, 7, 1, 5, 1, 1, 1, 4, 1, 1, 2, 1, 1] seq_len = sum(token_lens) # Create baseline attention with full context length - baseline_attn = self._create_baseline_attention(seq_len) + baseline_attn = self._create_baseline_attention(seq_len, kv_cache_type) # Create ring attention with sliding window size - ring_attn = self._create_ring_attention(baseline_attn) + ring_attn = self._create_ring_attention(baseline_attn, kv_cache_type) pos = 0 for token_len in token_lens: @@ -224,3 +323,15 @@ def 
test_large_context_with_sliding_window(self): f"Outputs differ at position {pos} with max difference {(baseline_out - ring_out).abs().max()}", ) pos += token_len + + def test_large_context_with_sliding_window_quantized(self): + """Test large context with sliding window with QuantizedKVCache.""" + self._run_test_with_kv_cache_type( + self.test_large_context_with_sliding_window, KVCacheType.QUANTIZED + ) + + def test_large_context_with_sliding_window_custom(self): + """Test large context with sliding window with CustomKVCache.""" + self._run_test_with_kv_cache_type( + self.test_large_context_with_sliding_window, KVCacheType.CUSTOM + )
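Not part of the patch, but a usage note for reviewers: the sketch below shows one way the new `replace_kv_cache_with_ring_kv_cache` transform could be driven to get a gemma3-style mix of local and global layers. Per the transform's contract, a `layer_sizes` entry of 0 leaves that layer's existing (global) KV cache untouched, while a non-zero entry swaps in a ring buffer with that sliding-window size. The sketch mirrors the mock-model setup in `test_replace_kv_cache.py`; the `Block` wrapper, the `ModelArgs` values, and the chosen window sizes are illustrative assumptions, and running it assumes `executorch.extension.llm.custom_ops` is importable (see the `preload_deps` added to the test targets).

```python
# Illustrative sketch only; not part of this patch.
import torch.nn as nn

from executorch.examples.models.llama.attention import AttentionMHA, Rope
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.examples.models.llama.source_transformation.custom_kv_cache import (
    replace_kv_cache_with_ring_kv_cache,
)


class Block(nn.Module):
    """Minimal stand-in for a transformer block; the transform only needs `.attention`."""

    def __init__(self, attention):
        super().__init__()
        self.attention = attention


args = ModelArgs(
    dim=32,
    n_heads=4,
    n_kv_heads=2,
    head_dim=8,
    max_batch_size=1,
    max_context_len=128,
    use_kv_cache=True,
    enable_dynamic_shape=True,
)
rope = Rope(args)

model = nn.Module()
model.layers = nn.ModuleList(
    [Block(AttentionMHA(args, layer_id=i, rope=rope)) for i in range(4)]
)

# One entry per layer: 0 keeps that layer on its existing (global) KV cache,
# a non-zero value installs a ring buffer with that sliding-window size.
layer_sizes = [16, 0, 16, 0]
replace_kv_cache_with_ring_kv_cache(model, layer_sizes)

# The same call also works after replace_kv_cache_with_custom_kv_cache or
# replace_kv_cache_with_quantized_kv_cache has been applied; the transform
# dispatches to CustomRingKVCache / QuantizedRingKVCache accordingly.
```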