diff --git a/mlx_lm/models/cache.py b/mlx_lm/models/cache.py index 6460ebfdb..d8851ffcb 100644 --- a/mlx_lm/models/cache.py +++ b/mlx_lm/models/cache.py @@ -676,7 +676,7 @@ def empty(self): @property def nbytes(self): - return sum(c.nbytes for c in self.cache) + return sum(c.nbytes for c in self.cache if c is not None) class ChunkedKVCache(_BaseCache): diff --git a/mlx_lm/server.py b/mlx_lm/server.py index 36695aa88..1530622c8 100644 --- a/mlx_lm/server.py +++ b/mlx_lm/server.py @@ -189,6 +189,7 @@ class CacheEntry: prompt_cache: List[Any] count: int nbytes: int + is_snapshot: bool = False # True if this is a snapshot copy @dataclass class SearchResult: @@ -198,9 +199,10 @@ class SearchResult: longer: List[int] common_prefix: int - def __init__(self, max_size: int = 10, max_bytes: int = 1 << 63): + def __init__(self, max_size: int = 10, max_bytes: int = 1 << 63, checkpoint_interval: int = 8192): self.max_size = max_size self.max_bytes = max_bytes + self.checkpoint_interval = checkpoint_interval self._cache = {} self._lru = deque() self._n_bytes = 0 @@ -212,6 +214,37 @@ def __len__(self): def nbytes(self): return self._n_bytes + def _find_checkpoint_positions(self, tokens_len, boundary_positions=None): + """Find positions where we should create checkpoint snapshots. + + Creates checkpoints at message boundaries that fall near checkpoint_interval + boundaries (e.g., around 8k tokens). This ensures we snapshot at natural + conversation break points without creating too many snapshots. + + Returns list of positions (not including 0 or tokens_len). + """ + positions = set() + + if not boundary_positions: + return [] + + # Find message boundaries near checkpoint intervals + # For each 8k interval, find the closest message boundary + last_checkpoint = 0 + for interval_start in range(self.checkpoint_interval, tokens_len, self.checkpoint_interval): + # Find message boundaries within this interval + candidates = [p for p in boundary_positions + if last_checkpoint < p < tokens_len and p <= interval_start + self.checkpoint_interval // 2] + + if candidates: + # Pick the last boundary in this range (most complete message) + best = max(candidates) + if best > last_checkpoint: + positions.add(best) + last_checkpoint = best + + return sorted(positions) + def _search(self, model, tokens): """Search the cache for a prompt cache. Return exact or close match.""" if model not in self._cache: @@ -274,28 +307,56 @@ def _delete(self, model, tokens): logging.debug(f"[LRUPromptCache] Removed {cache_bytes} bytes from the cache") - def _extract(self, model, tokens): + def _extract(self, model, tokens, keep_original=False): cache_entry = self._get(model, tokens) + is_snapshot = cache_entry.is_snapshot + + # For snapshots, always extract (they are already copies) + # For non-snapshots with count=1, extract or keep based on keep_original if cache_entry.count == 1: + if keep_original and not is_snapshot: + # Return the SAME cache object (not a copy) for hybrid models + # This allows the cache to be updated in place and shared across positions + return self.CacheEntry( + cache_entry.prompt_cache, 1, cache_entry.nbytes + ) self._delete(model, tokens) self._lru.remove((model, tokens)) return cache_entry cache_entry.count -= 1 + # Return the SAME cache object (not a copy) return self.CacheEntry( - copy.deepcopy(cache_entry.prompt_cache), 1, cache_entry.nbytes + cache_entry.prompt_cache, 1, cache_entry.nbytes ) def fetch_nearest_cache(self, model, tokens): result = self._search(model, tokens) + requested_len = len(tokens) + if result.exact is not None: cache_entry = self._extract(result.model, result.exact) - return cache_entry.prompt_cache, [] + logging.debug(f"[LRUPromptCache] EXACT match: {requested_len} tokens") + return cache_entry.prompt_cache, [], result.exact if result.shorter is not None: - cache_entry = self._extract(result.model, result.shorter) + cache_entry = self._get(model, result.shorter) + is_snapshot = cache_entry.is_snapshot prefix_len = len(result.shorter) - return cache_entry.prompt_cache, tokens[prefix_len:] + remaining = len(tokens) - prefix_len + + if is_snapshot: + # Snapshots are already copies - extract them (don't keep original) + # Don't return old_position since we're branching from a snapshot + cache_entry = self._extract(result.model, result.shorter, keep_original=False) + logging.debug(f"[LRUPromptCache] CHECKPOINT match: {prefix_len} cached (snapshot), {remaining} to process (requested {requested_len})") + return cache_entry.prompt_cache, tokens[prefix_len:], None + else: + # Non-snapshot: return the SAME cache object for efficient reuse + # Also return the old position so we can move the cache later + cache_entry = self._extract(result.model, result.shorter, keep_original=True) + logging.debug(f"[LRUPromptCache] SHORTER match: {prefix_len} cached, {remaining} to process (requested {requested_len})") + return cache_entry.prompt_cache, tokens[prefix_len:], result.shorter if result.longer is not None: cache_entry = self._get(result.model, result.longer) @@ -304,11 +365,31 @@ def fetch_nearest_cache(self, model, tokens): prefix = min(len(tokens) - 1, result.common_prefix) num_to_trim = len(result.longer) - prefix trim_prompt_cache(cache, num_to_trim) - return cache, tokens[prefix:] + logging.debug(f"[LRUPromptCache] LONGER match: trimmed {num_to_trim} tokens") + return cache, tokens[prefix:], None + + logging.debug(f"[LRUPromptCache] NO match for {requested_len} tokens") + return None, tokens, None - return None, tokens + def insert_cache(self, model, tokens, prompt_cache, boundary_positions=None, old_position=None): + """Insert cache at the full token position with checkpoint snapshots. + + Args: + model: The model identifier + tokens: The full token sequence + prompt_cache: The cache to store + boundary_positions: Optional list of positions where message boundaries occur + old_position: If provided, delete the cache at this old position (for cache moves) + """ + # If we're extending a cache, delete the old position first + if old_position is not None and old_position != tokens: + self._delete(model, old_position) + try: + self._lru.remove((model, old_position)) + except ValueError: + pass # Already not in LRU + logging.debug(f"[LRUPromptCache] Moved cache from {len(old_position)} to {len(tokens)}") - def insert_cache(self, model, tokens, prompt_cache): if model not in self._cache: self._cache[model] = {} current = self._cache[model] @@ -318,22 +399,72 @@ def insert_cache(self, model, tokens, prompt_cache): current = current[tok] if "cache" in current: - current["cache"].count += 1 + # Update existing cache with new state + old_bytes = current["cache"].nbytes + new_bytes = sum(c.nbytes for c in prompt_cache) + current["cache"] = self.CacheEntry(prompt_cache, current["cache"].count + 1, new_bytes) + self._n_bytes += (new_bytes - old_bytes) self._lru.remove((model, tokens)) + logging.debug(f"[LRUPromptCache] Updating cache at position {len(tokens)}, bytes: {old_bytes} -> {new_bytes}") else: cache_bytes = sum(c.nbytes for c in prompt_cache) current["cache"] = self.CacheEntry(prompt_cache, 1, cache_bytes) self._n_bytes += cache_bytes - logging.debug(f"[LRUPromptCache] Adding {cache_bytes} to the cache") + logging.debug(f"[LRUPromptCache] Adding {cache_bytes} bytes at position {len(tokens)}") self._lru.append((model, tokens)) - if len(self._lru) > self.max_size: - model, tokens = self._lru.popleft() + + # Create checkpoint snapshots at interval and message boundaries + checkpoint_positions = self._find_checkpoint_positions(len(tokens), boundary_positions) + for pos in checkpoint_positions: + self._create_checkpoint_snapshot(model, tokens, prompt_cache, pos) + + # Evict oldest entries first when limits exceeded (LRU eviction) + # This removes caches that haven't been used recently + while len(self._lru) > self.max_size: + model, tokens = self._lru.popleft() # Remove oldest (first in) self._delete(model, tokens) while self._n_bytes > self.max_bytes and len(self._lru) > 1: - model, tokens = self._lru.popleft() + model, tokens = self._lru.popleft() # Remove oldest (first in) self._delete(model, tokens) + def _create_checkpoint_snapshot(self, model, tokens, prompt_cache, position): + """Create a deep copy snapshot of the cache at a checkpoint position. + + For hybrid models (KVCache + ArraysCache), we: + - Deep copy the entire cache state + - Trim KVCache layers to the checkpoint position + - ArraysCache layers are copied as-is (they maintain full state) + """ + if model not in self._cache: + self._cache[model] = {} + + checkpoint_tokens = tokens[:position] + current = self._cache[model] + for tok in checkpoint_tokens: + if tok not in current: + current[tok] = {} + current = current[tok] + + # Only create snapshot if one doesn't already exist at this position + if "cache" in current: + return + + # Deep copy the cache + snapshot = copy.deepcopy(prompt_cache) + + # Trim KVCache layers to the checkpoint position + # This works for trimmable caches (KVCache), non-trimmable caches (ArraysCache) stay as-is + tokens_to_trim = len(tokens) - position + if tokens_to_trim > 0 and can_trim_prompt_cache(snapshot): + trim_prompt_cache(snapshot, tokens_to_trim) + + snapshot_bytes = sum(c.nbytes for c in snapshot) + current["cache"] = self.CacheEntry(snapshot, 1, snapshot_bytes, is_snapshot=True) + self._n_bytes += snapshot_bytes + self._lru.append((model, checkpoint_tokens)) + logging.debug(f"[LRUPromptCache] Created checkpoint snapshot at position {position} ({snapshot_bytes} bytes)") + def trim_to( self, *, n_sequences: Optional[int] = None, n_bytes: Optional[int] = None ): @@ -714,6 +845,80 @@ def _tokenize(self, tokenizer, request, args): else: return tokenizer.encode(request.prompt) + def _tokenize_for_cache_key(self, tokenizer, request, args): + """Tokenize without generation prompt for cache key computation. + + This ensures prompt chaining works correctly by using the message + content only (without generation suffix) as the cache key. + """ + if request.request_type == "chat": + messages = request.messages + tools = request.tools + + if tokenizer.has_chat_template: + chat_template_args = self.model_provider.cli_args.chat_template_args + if args.chat_template_kwargs: + chat_template_args = chat_template_args.copy() + chat_template_args.update(args.chat_template_kwargs) + return tokenizer.apply_chat_template( + messages, + tools=tools, + add_generation_prompt=False, + tokenize=True, + **chat_template_args, + ) + else: + return tokenizer.encode(convert_chat(messages, request.role_mapping)) + else: + return tokenizer.encode(request.prompt) + + def _compute_message_boundaries(self, tokenizer, request, args): + """Compute token positions where each message ends in the cache key. + + This is used to create checkpoint snapshots at message boundaries, + allowing cache reuse when conversations branch in different directions. + + Args: + tokenizer: The tokenizer to use + request: The request containing messages + args: Generation arguments (for chat_template_kwargs) + + Returns: + List of token positions where messages end + """ + if request.request_type != "chat": + return [] + + messages = request.messages + tools = request.tools + if not messages or not tokenizer.has_chat_template: + return [] + + chat_template_args = self.model_provider.cli_args.chat_template_args + if args.chat_template_kwargs: + chat_template_args = chat_template_args.copy() + chat_template_args.update(args.chat_template_kwargs) + + boundaries = [] + + # Tokenize progressively to find message boundary positions + for i in range(1, len(messages)): + try: + partial_tokens = tokenizer.apply_chat_template( + messages[:i], + tools=tools, + add_generation_prompt=False, + tokenize=True, + **chat_template_args, + ) + if partial_tokens: + boundaries.append(len(partial_tokens)) + except Exception: + # If tokenization fails for partial messages, skip this boundary + continue + + return boundaries + def _is_batchable(self, args): if not self.model_provider.is_batchable: return False @@ -722,6 +927,71 @@ def _is_batchable(self, args): return True + def _find_cache_boundaries(self, tokens, tokenizer, min_chunk_size=10): + """Find positions in tokens where we should create cache boundaries. + + This identifies key positions (like after the system message) where + storing a separate cache enables better prompt chaining when + conversations branch in different directions. + + Args: + tokens: The tokenized prompt + tokenizer: The tokenizer to use + min_chunk_size: Minimum tokens before a boundary is considered + + Returns: + List of positions where cache boundaries should be created + """ + boundaries = [] + + # Find the end-of-message token (typically <|im_end|> or similar) + end_token = None + try: + im_end_tokens = tokenizer.encode("<|im_end|>", add_special_tokens=False) + if len(im_end_tokens) == 1: + end_token = im_end_tokens[0] + except: + pass + + if end_token is None: + # Try alternative end tokens + for end_str in ["", "<|end|>", "[/INST]"]: + try: + end_tokens = tokenizer.encode(end_str, add_special_tokens=False) + if len(end_tokens) == 1: + end_token = end_tokens[0] + break + except: + pass + + if end_token is None: + return boundaries + + # Find positions after each complete message + # We only store at the FIRST message boundary (after system message) + # and at the last position (full prompt) + for i, tok in enumerate(tokens): + if tok == end_token: + # Position after this token and any following newline + pos = i + 1 + # Skip following newline if present + if pos < len(tokens): + try: + next_decoded = tokenizer.decode([tokens[pos]]) + if next_decoded in ['\n', '']: + pos += 1 + except: + pass + + # Only add if it's a significant chunk + if pos >= min_chunk_size and pos < len(tokens): + boundaries.append(pos) + # Only store at the first message boundary (usually after system message) + # This is the most common branching point + break + + return boundaries + def _generate(self): current_model = None current_sampling = None @@ -770,6 +1040,13 @@ def progress_callback(info): ): try: prompt = self._tokenize(current_tokenizer, request, args) + cache_key_prompt = self._tokenize_for_cache_key( + current_tokenizer, request, args + ) + # Compute message boundary positions for checkpoint snapshots + boundary_positions = self._compute_message_boundaries( + current_tokenizer, request, args + ) except Exception as e: rqueue.put(e) continue @@ -792,12 +1069,19 @@ def progress_callback(info): ) rqueue.put(ctx) - cache, rest = self.prompt_cache.fetch_nearest_cache( - current_model_key, prompt + cache, rest, old_position = self.prompt_cache.fetch_nearest_cache( + current_model_key, cache_key_prompt ) - ctx.prompt_cache_count = len(prompt) - len(rest) + ctx.prompt_cache_count = len(cache_key_prompt) - len(rest) + + # Include generation suffix in rest for batch processing + gen_suffix_len = len(prompt) - len(cache_key_prompt) + if gen_suffix_len > 0: + rest = list(rest) + list(prompt[-gen_suffix_len:]) + if cache is None: cache = make_prompt_cache(self.model_provider.model) + old_position = None ncaches, nbytes = len(self.prompt_cache), self.prompt_cache.nbytes logging.info( @@ -813,7 +1097,9 @@ def progress_callback(info): ) batch_results[uid] = { "ctx": ctx, - "cache_key": prompt[:], + "cache_key": list(cache_key_prompt), + "boundary_positions": boundary_positions, + "old_position": old_position, "rqueue": rqueue, "detokenizer": tokenizer.detokenizer, } @@ -884,7 +1170,6 @@ def progress_callback(info): for r in responses: result = batch_results[r.uid] - result["cache_key"].append(r.token) if r.finish_reason != "stop": result["detokenizer"].add_token(r.token) @@ -903,7 +1188,9 @@ def progress_callback(info): if r.finish_reason is not None: result["rqueue"].put(None) self.prompt_cache.insert_cache( - current_model_key, result["cache_key"], r.prompt_cache + current_model_key, result["cache_key"], r.prompt_cache, + boundary_positions=result.get("boundary_positions"), + old_position=result.get("old_position") ) del batch_results[r.uid] @@ -921,7 +1208,9 @@ def progress_callback(info): continue result = batch_results[uid] self.prompt_cache.insert_cache( - current_model_key, result["cache_key"], prompt_cache + current_model_key, result["cache_key"], prompt_cache, + boundary_positions=result.get("boundary_positions"), + old_position=result.get("old_position") ) del batch_results[uid] @@ -938,9 +1227,15 @@ def progress(tokens_processed, tokens_total): tokenizer = self.model_provider.tokenizer draft_model = self.model_provider.draft_model - # Prepare the prompt + # Prepare the prompt for generation (with generation suffix) prompt = self._tokenize(tokenizer, request, args) + # Prepare the cache key prompt (without generation suffix for prompt chaining) + cache_key_prompt = self._tokenize_for_cache_key(tokenizer, request, args) + + # Compute message boundary positions for checkpoint snapshots + boundary_positions = self._compute_message_boundaries(tokenizer, request, args) + # Start the generation context ctx = GenerationContext( has_tool_calling=tokenizer.has_tool_calling, @@ -968,16 +1263,27 @@ def progress(tokens_processed, tokens_total): sampler = _make_sampler(args, tokenizer) logits_processors = _make_logits_processors(args) - # Load the KV cache - cache, rest = self.prompt_cache.fetch_nearest_cache( - self.model_provider.model_key, prompt + # Load the KV cache using cache_key_prompt (without gen suffix) + # Returns: cache, remaining tokens, old_position (for cache move) + cache, rest, old_position = self.prompt_cache.fetch_nearest_cache( + self.model_provider.model_key, cache_key_prompt ) - ctx.prompt_cache_count = len(prompt) - len(rest) - cache_key = prompt[:] + ctx.prompt_cache_count = len(cache_key_prompt) - len(rest) + + # Store cache at cache_key_prompt position for prompt chaining support + cache_key = list(cache_key_prompt) + + # Compute the generation suffix (prompt - cache_key_prompt) + gen_suffix_len = len(prompt) - len(cache_key_prompt) + if gen_suffix_len > 0: + # If we have a cache hit, we need to include the gen suffix in the rest + rest = list(rest) + list(prompt[-gen_suffix_len:]) + if cache is None: cache = make_prompt_cache(self.model_provider.model) if self.model_provider.draft_model is not None: cache += make_prompt_cache(self.model_provider.draft_model) + old_position = None # New cache, no old position to delete ncaches, nbytes = len(self.prompt_cache), self.prompt_cache.nbytes logging.info(f"We have {ncaches} kv caches that take {nbytes/1e9:.2f} GB") @@ -1006,7 +1312,6 @@ def progress(tokens_processed, tokens_total): ), ) ) - cache_key.append(gen.token) if ctx._should_stop: if self._is_distributed: @@ -1015,9 +1320,12 @@ def progress(tokens_processed, tokens_total): rqueue.put(None) - # Save the KV cache again + # Save the KV cache with checkpoint snapshots at message boundaries + # This enables cache reuse when conversations branch in different directions self.prompt_cache.insert_cache( - self.model_provider.model_key, cache_key, cache + self.model_provider.model_key, cache_key, cache, + boundary_positions=boundary_positions, + old_position=old_position ) except Exception as e: diff --git a/tests/test_hybrid_cache.py b/tests/test_hybrid_cache.py new file mode 100644 index 000000000..4ddd761bd --- /dev/null +++ b/tests/test_hybrid_cache.py @@ -0,0 +1,793 @@ +#!/usr/bin/env python3 +""" +Test suite for hybrid model caching in mlx_lm server. + +This tests the prompt chaining functionality for hybrid models (like Qwen3.5) +that use a mix of KVCache and ArraysCache. +""" + +import json +import subprocess +import sys +import time +import urllib.request +import urllib.error + +# Test configuration +SERVER_HOST = "127.0.0.1" +SERVER_PORT = 1117 +MODEL_PATH = "/Users/sombrax/.lmstudio/models/sombra/QWEN-3.5-MXFP4" +CHAT_TEMPLATE_PATH = "/Users/sombrax/VibeCoding/mlx_server/templates/qwen35.jinja" +SERVER_STARTUP_TIMEOUT = 120 # seconds +REQUEST_TIMEOUT = 180 # seconds + + +def start_server(): + """Start the mlx_lm server.""" + with open(CHAT_TEMPLATE_PATH) as f: + chat_template = f.read() + + cmd = [ + "python3", "-m", "mlx_lm.server", + "--host", SERVER_HOST, + "--port", str(SERVER_PORT), + "--model", MODEL_PATH, + "--trust-remote-code", + "--log-level", "INFO", + "--chat-template", chat_template, + ] + + proc = subprocess.Popen( + cmd, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True + ) + + # Wait for server to be ready + start_time = time.time() + while time.time() - start_time < SERVER_STARTUP_TIMEOUT: + try: + req = urllib.request.urlopen( + f"http://{SERVER_HOST}:{SERVER_PORT}/health", + timeout=5 + ) + if req.status == 200: + return proc + except (urllib.error.URLError, ConnectionRefusedError): + pass + time.sleep(2) + + proc.terminate() + raise RuntimeError("Server failed to start within timeout") + + +def stop_server(proc): + """Stop the mlx_lm server.""" + proc.terminate() + try: + proc.wait(timeout=10) + except subprocess.TimeoutExpired: + proc.kill() + + +def make_request(messages, max_tokens=50, enable_thinking=False): + """Make a chat completion request.""" + data = { + "model": MODEL_PATH, + "messages": messages, + "max_tokens": max_tokens, + "temperature": 0.0, + "chat_template_kwargs": {"enable_thinking": enable_thinking} + } + + req = urllib.request.Request( + f"http://{SERVER_HOST}:{SERVER_PORT}/v1/chat/completions", + data=json.dumps(data).encode(), + headers={"Content-Type": "application/json"}, + method="POST" + ) + + try: + response = urllib.request.urlopen(req, timeout=REQUEST_TIMEOUT) + return json.loads(response.read().decode()) + except urllib.error.HTTPError as e: + return {"error": e.read().decode()} + + +def test_basic_generation(): + """Test basic generation works.""" + print("\n=== Test: Basic Generation ===") + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Say hello"} + ] + response = make_request(messages, enable_thinking=False) + + if "error" in response: + print(f"FAILED: {response['error']}") + return False + + cached = response.get("usage", {}).get("prompt_tokens_details", {}).get("cached_tokens", 0) + content = response.get("choices", [{}])[0].get("message", {}).get("content", "") + + print(f"Prompt tokens: {response.get('usage', {}).get('prompt_tokens', 'N/A')}") + print(f"Cached tokens: {cached}") + print(f"Generated: {content[:100]}") + + assert cached == 0, "First request should have 0 cached tokens" + assert len(content) > 0, "Should generate content" + + print("PASSED") + return True + + +def test_prompt_chaining(): + """Test that prompt chaining uses cache.""" + print("\n=== Test: Prompt Chaining ===") + + # First request + messages1 = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?"} + ] + response1 = make_request(messages1, enable_thinking=False) + + if "error" in response1: + print(f"FAILED (request 1): {response1['error']}") + return False + + content1 = response1.get("choices", [{}])[0].get("message", {}).get("content", "") + print(f"First response: {content1[:100]}") + + # Second request (extending the conversation) + messages2 = messages1 + [ + {"role": "assistant", "content": content1.split('.')[0] + "."}, # Use part of the actual response + {"role": "user", "content": "What about Germany?"} + ] + response2 = make_request(messages2, enable_thinking=False) + + if "error" in response2: + print(f"FAILED (request 2): {response2['error']}") + return False + + cached2 = response2.get("usage", {}).get("prompt_tokens_details", {}).get("cached_tokens", 0) + content2 = response2.get("choices", [{}])[0].get("message", {}).get("content", "") + + print(f"Second request prompt tokens: {response2.get('usage', {}).get('prompt_tokens', 'N/A')}") + print(f"Second request cached tokens: {cached2}") + print(f"Second response: {content2[:100]}") + + assert cached2 > 0, f"Second request should use cache, but cached_tokens={cached2}" + + print("PASSED") + return True + + +def test_extended_prompt_chaining(): + """Test extended prompt chaining with multiple turns.""" + print("\n=== Test: Extended Prompt Chaining ===") + + # Use unique system prompt to avoid interference from other tests + import time + unique_id = str(int(time.time() * 1000)) + + # Build up a conversation + messages = [ + {"role": "system", "content": f"You are a helpful assistant. [{unique_id}]"}, + ] + + cached_tokens_history = [] + + # First turn + messages.append({"role": "user", "content": "My name is Alice."}) + response = make_request(messages, enable_thinking=False) + if "error" in response: + print(f"FAILED (turn 1): {response['error']}") + return False + content = response.get("choices", [{}])[0].get("message", {}).get("content", "") + messages.append({"role": "assistant", "content": content.split('.')[0] + "."}) + cached = response.get("usage", {}).get("prompt_tokens_details", {}).get("cached_tokens", 0) + cached_tokens_history.append(cached) + print(f"Turn 1: cached={cached}, response={content[:50]}...") + + # Second turn + messages.append({"role": "user", "content": "What is my name?"}) + response = make_request(messages, enable_thinking=False) + if "error" in response: + print(f"FAILED (turn 2): {response['error']}") + return False + content = response.get("choices", [{}])[0].get("message", {}).get("content", "") + messages.append({"role": "assistant", "content": content.split('.')[0] + "."}) + cached = response.get("usage", {}).get("prompt_tokens_details", {}).get("cached_tokens", 0) + cached_tokens_history.append(cached) + print(f"Turn 2: cached={cached}, response={content[:50]}...") + + # Third turn + messages.append({"role": "user", "content": "Can you count to 5?"}) + response = make_request(messages, enable_thinking=False) + if "error" in response: + print(f"FAILED (turn 3): {response['error']}") + return False + cached = response.get("usage", {}).get("prompt_tokens_details", {}).get("cached_tokens", 0) + cached_tokens_history.append(cached) + print(f"Turn 3: cached={cached}") + + # Verify cache is being used + assert cached_tokens_history[0] == 0, "First request should have 0 cached tokens" + assert cached_tokens_history[1] > 0, "Second request should use cache" + assert cached_tokens_history[2] > cached_tokens_history[1], "Third request should use more cache" + + print("PASSED") + return True + + +def test_cache_persistence(): + """Test that cache persists when extending the same conversation.""" + print("\n=== Test: Cache Persistence (Same Conversation Extension) ===") + + # Use unique system prompt to avoid interference from other tests + import time + unique_id = str(int(time.time() * 1000)) + + # First request + messages1 = [ + {"role": "system", "content": f"You are a helpful assistant. [{unique_id}]"}, + {"role": "user", "content": "My favorite color is blue."} + ] + response1 = make_request(messages1, enable_thinking=False) + if "error" in response1: + print(f"FAILED (request 1): {response1['error']}") + return False + content1 = response1.get("choices", [{}])[0].get("message", {}).get("content", "") + + # Second request - extend with actual generated response (should use cache from first) + messages2 = messages1 + [ + {"role": "assistant", "content": content1.strip()}, + {"role": "user", "content": "What is my favorite color?"} + ] + response2 = make_request(messages2, enable_thinking=False) + if "error" in response2: + print(f"FAILED (request 2): {response2['error']}") + return False + cached2 = response2.get("usage", {}).get("prompt_tokens_details", {}).get("cached_tokens", 0) + content2 = response2.get("choices", [{}])[0].get("message", {}).get("content", "") + + # Third request - continue extending with actual generated response (should use cache from second) + messages3 = messages2 + [ + {"role": "assistant", "content": content2.strip()}, + {"role": "user", "content": "Do you remember my favorite color?"} + ] + response3 = make_request(messages3, enable_thinking=False) + if "error" in response3: + print(f"FAILED (request 3): {response3['error']}") + return False + cached3 = response3.get("usage", {}).get("prompt_tokens_details", {}).get("cached_tokens", 0) + + print(f"Request 1 cached: 0 (first request)") + print(f"Request 2 cached: {cached2}") + print(f"Request 3 cached: {cached3}") + + # Each subsequent request should have more cached tokens than the previous + assert cached2 > 0, f"Request 2 should use cache, but cached_tokens={cached2}" + assert cached3 > cached2, f"Request 3 should have more cached tokens than request 2, but cached3={cached3} vs cached2={cached2}" + + print("PASSED") + return True + + +def test_cache_invalidation(): + """Test that different conversations don't share cache.""" + print("\n=== Test: Cache Invalidation ===") + + # Use unique system prompts to avoid interference + import time + unique_id = str(int(time.time() * 1000)) + + # Conversation A + messages_a = [ + {"role": "system", "content": f"You are a helpful assistant. [{unique_id}A]"}, + {"role": "user", "content": "Remember the number 42."} + ] + response_a1 = make_request(messages_a, enable_thinking=False) + if "error" in response_a1: + print(f"FAILED (conv A): {response_a1['error']}") + return False + + # Conversation B (different system prompt) + messages_b = [ + {"role": "system", "content": f"You are a pirate assistant. Arr! [{unique_id}B]"}, + {"role": "user", "content": "Remember the number 42."} + ] + response_b = make_request(messages_b, enable_thinking=False) + if "error" in response_b: + print(f"FAILED (conv B): {response_b['error']}") + return False + + cached_b = response_b.get("usage", {}).get("prompt_tokens_details", {}).get("cached_tokens", 0) + + # Conversation B should not use cache from conversation A (different system prompt) + assert cached_b == 0, f"Different conversations should not share cache, but cached_tokens={cached_b}" + + print("PASSED") + return True + + +def test_conversation_branching(): + """Test that extending a conversation properly moves the cache. + + Note: With the 'move' caching strategy, branching from intermediate points + (like the system prompt) is not supported. Each conversation path has its + own cache that gets moved as the conversation extends. + """ + print("\n=== Test: Conversation Extension (Cache Move) ===") + + # Use unique system prompt to avoid interference + import time + unique_id = str(int(time.time() * 1000)) + system_prompt = f"You are a math bot. Answer briefly. [{unique_id}]" + + # Request 1: Establish conversation + messages1 = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": "1+1=?"} + ] + response1 = make_request(messages1, enable_thinking=False) + if "error" in response1: + print(f"FAILED (request 1): {response1['error']}") + return False + + content1 = response1.get("choices", [{}])[0].get("message", {}).get("content", "") + cached1 = response1.get("usage", {}).get("prompt_tokens_details", {}).get("cached_tokens", 0) + print(f"Request 1 (first): {cached1} cached") + + # Request 2: Extend conversation (should use cache from request 1) + messages2 = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": "1+1=?"}, + {"role": "assistant", "content": content1.strip()}, + {"role": "user", "content": "2+2=?"} + ] + response2 = make_request(messages2, enable_thinking=False) + if "error" in response2: + print(f"FAILED (request 2): {response2['error']}") + return False + + content2 = response2.get("choices", [{}])[0].get("message", {}).get("content", "") + cached2 = response2.get("usage", {}).get("prompt_tokens_details", {}).get("cached_tokens", 0) + print(f"Request 2 (extend): {cached2} cached") + + # Request 3: Extend further (should use cache from request 2) + messages3 = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": "1+1=?"}, + {"role": "assistant", "content": content1.strip()}, + {"role": "user", "content": "2+2=?"}, + {"role": "assistant", "content": content2.strip()}, + {"role": "user", "content": "3+3=?"} + ] + response3 = make_request(messages3, enable_thinking=False) + if "error" in response3: + print(f"FAILED (request 3): {response3['error']}") + return False + + cached3 = response3.get("usage", {}).get("prompt_tokens_details", {}).get("cached_tokens", 0) + print(f"Request 3 (extend further): {cached3} cached") + + # Verify caching behavior: each extension should have more cached tokens + assert cached1 == 0, "First request should have 0 cached tokens" + assert cached2 > cached1, f"Request 2 should have more cache than request 1" + assert cached3 > cached2, f"Request 3 should have more cache than request 2" + + print("PASSED") + return True + + +def test_cache_survives_multiple_branches(): + """Test that multiple conversation extensions work correctly. + + Note: With the 'move' caching strategy, we maintain one cache per + conversation path, not per system prompt. + """ + print("\n=== Test: Multiple Extensions ===") + + # Use unique system prompt to avoid interference + import time + unique_id = str(int(time.time() * 1000)) + system_prompt = f"You are a helpful assistant. Be very brief. [{unique_id}]" + + # Establish conversation + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": "First message"} + ] + response = make_request(messages, enable_thinking=False) + if "error" in response: + print(f"FAILED (initial): {response['error']}") + return False + + last_content = response.get("choices", [{}])[0].get("message", {}).get("content", "") + cached_tokens_list = [0] + + # Multiple extensions should each use the previous cache + for i in range(5): + messages.append({"role": "assistant", "content": last_content.strip()}) + messages.append({"role": "user", "content": f"Message {i+2}"}) + response = make_request(messages, enable_thinking=False) + if "error" in response: + print(f"FAILED (extension {i}): {response['error']}") + return False + + cached = response.get("usage", {}).get("prompt_tokens_details", {}).get("cached_tokens", 0) + cached_tokens_list.append(cached) + last_content = response.get("choices", [{}])[0].get("message", {}).get("content", "") + print(f"Extension {i+1}: {cached} cached") + + # Each extension should have more cached tokens than the previous + for i in range(1, len(cached_tokens_list)): + assert cached_tokens_list[i] > cached_tokens_list[i-1], \ + f"Extension {i} should have more cache than extension {i-1}" + + print("PASSED") + return True + + +def test_exact_match_reuse(): + """Test that identical requests reuse exact cache match.""" + print("\n=== Test: Exact Match Reuse ===") + + # Use unique system prompt to avoid interference + import time + unique_id = str(int(time.time() * 1000)) + + messages = [ + {"role": "system", "content": f"You are a helpful assistant. [{unique_id}]"}, + {"role": "user", "content": "What is 2+2?"} + ] + + # First request + response1 = make_request(messages, enable_thinking=False) + if "error" in response1: + print(f"FAILED (request 1): {response1['error']}") + return False + cached1 = response1.get("usage", {}).get("prompt_tokens_details", {}).get("cached_tokens", 0) + print(f"First request cached: {cached1}") + + # Second identical request (should get exact match) + response2 = make_request(messages, enable_thinking=False) + if "error" in response2: + print(f"FAILED (request 2): {response2['error']}") + return False + cached2 = response2.get("usage", {}).get("prompt_tokens_details", {}).get("cached_tokens", 0) + prompt_tokens2 = response2.get("usage", {}).get("prompt_tokens", 0) + print(f"Second request cached: {cached2}, prompt_tokens: {prompt_tokens2}") + + # Second request should have all tokens cached (exact match) + assert cached1 == 0, "First request should have 0 cached tokens" + assert cached2 == prompt_tokens2, f"Second request should have all tokens cached (exact match), got {cached2}/{prompt_tokens2}" + + print("PASSED") + return True + + +def test_thinking_mode_caching(): + """Test that thinking mode works with caching.""" + print("\n=== Test: Thinking Mode Caching ===") + + # Use unique system prompt + import time + unique_id = str(int(time.time() * 1000)) + + messages1 = [ + {"role": "system", "content": f"You are a helpful assistant. [{unique_id}]"}, + {"role": "user", "content": "Think about what 5+5 equals."} + ] + + # First request with thinking enabled + response1 = make_request(messages1, max_tokens=100, enable_thinking=True) + if "error" in response1: + print(f"FAILED (request 1): {response1['error']}") + return False + cached1 = response1.get("usage", {}).get("prompt_tokens_details", {}).get("cached_tokens", 0) + content1 = response1.get("choices", [{}])[0].get("message", {}).get("content", "") + reasoning1 = response1.get("choices", [{}])[0].get("message", {}).get("reasoning", "") + print(f"First request cached: {cached1}") + print(f"Has reasoning: {len(reasoning1) > 0}") + + # Second request extending the conversation + messages2 = messages1 + [ + {"role": "assistant", "content": content1[:50] if content1 else "10"}, + {"role": "user", "content": "Now what about 6+6?"} + ] + response2 = make_request(messages2, max_tokens=100, enable_thinking=True) + if "error" in response2: + print(f"FAILED (request 2): {response2['error']}") + return False + cached2 = response2.get("usage", {}).get("prompt_tokens_details", {}).get("cached_tokens", 0) + print(f"Second request cached: {cached2}") + + assert cached1 == 0, "First request should have 0 cached tokens" + assert cached2 > 0, f"Second request should use cache, but cached_tokens={cached2}" + + print("PASSED") + return True + + +def test_interleaved_conversations(): + """Test multiple interleaved conversations don't interfere.""" + print("\n=== Test: Interleaved Conversations ===") + + import time + unique_id = str(int(time.time() * 1000)) + + # Conversation A + messages_a = [ + {"role": "system", "content": f"You are a math bot. [{unique_id}A]"}, + {"role": "user", "content": "Remember the number 100."} + ] + response_a1 = make_request(messages_a, enable_thinking=False) + if "error" in response_a1: + print(f"FAILED (conv A1): {response_a1['error']}") + return False + content_a1 = response_a1.get("choices", [{}])[0].get("message", {}).get("content", "") + print(f"Conv A1: Got response") + + # Conversation B (different) + messages_b = [ + {"role": "system", "content": f"You are a math bot. [{unique_id}B]"}, + {"role": "user", "content": "Remember the number 200."} + ] + response_b1 = make_request(messages_b, enable_thinking=False) + if "error" in response_b1: + print(f"FAILED (conv B1): {response_b1['error']}") + return False + content_b1 = response_b1.get("choices", [{}])[0].get("message", {}).get("content", "") + print(f"Conv B1: Got response") + + # Extend conversation A + messages_a2 = messages_a + [ + {"role": "assistant", "content": content_a1[:50]}, + {"role": "user", "content": "What number did I tell you?"} + ] + response_a2 = make_request(messages_a2, enable_thinking=False) + if "error" in response_a2: + print(f"FAILED (conv A2): {response_a2['error']}") + return False + cached_a2 = response_a2.get("usage", {}).get("prompt_tokens_details", {}).get("cached_tokens", 0) + print(f"Conv A2: cached={cached_a2}") + + # Extend conversation B + messages_b2 = messages_b + [ + {"role": "assistant", "content": content_b1[:50]}, + {"role": "user", "content": "What number did I tell you?"} + ] + response_b2 = make_request(messages_b2, enable_thinking=False) + if "error" in response_b2: + print(f"FAILED (conv B2): {response_b2['error']}") + return False + cached_b2 = response_b2.get("usage", {}).get("prompt_tokens_details", {}).get("cached_tokens", 0) + print(f"Conv B2: cached={cached_b2}") + + # Both conversations should use their own cache + assert cached_a2 > 0, f"Conv A2 should use cache, but cached_tokens={cached_a2}" + assert cached_b2 > 0, f"Conv B2 should use cache, but cached_tokens={cached_b2}" + + print("PASSED") + return True + + +def test_long_conversation_chain(): + """Test caching with a long conversation chain.""" + print("\n=== Test: Long Conversation Chain ===") + + import time + unique_id = str(int(time.time() * 1000)) + + messages = [ + {"role": "system", "content": f"You are a helpful assistant. Be very brief. [{unique_id}]"} + ] + + cached_tokens_history = [] + + # Build up a conversation with 10 turns + for i in range(10): + messages.append({"role": "user", "content": f"Turn {i+1}: Say 'ok'."}) + response = make_request(messages, max_tokens=5, enable_thinking=False) + if "error" in response: + print(f"FAILED (turn {i+1}): {response['error']}") + return False + + cached = response.get("usage", {}).get("prompt_tokens_details", {}).get("cached_tokens", 0) + cached_tokens_history.append(cached) + + content = response.get("choices", [{}])[0].get("message", {}).get("content", "") + messages.append({"role": "assistant", "content": content[:20]}) + + if i < 3 or i >= 8: # Print first few and last few + print(f"Turn {i+1}: cached={cached}") + + # Verify cache grows with each turn + assert cached_tokens_history[0] == 0, "First request should have 0 cached tokens" + + # Check that cache generally increases (with some tolerance for slight variations) + for i in range(1, len(cached_tokens_history)): + # Allow small variations due to tokenization differences + assert cached_tokens_history[i] >= cached_tokens_history[i-1] - 5, \ + f"Turn {i+1} cache ({cached_tokens_history[i]}) should not be much less than turn {i} ({cached_tokens_history[i-1]})" + + print(f"Final cached tokens: {cached_tokens_history[-1]}") + print("PASSED") + return True + + +def test_cache_with_different_params(): + """Test that different generation params still use same cache for prompt.""" + print("\n=== Test: Cache with Different Generation Params ===") + + import time + unique_id = str(int(time.time() * 1000)) + + messages = [ + {"role": "system", "content": f"You are a helpful assistant. [{unique_id}]"}, + {"role": "user", "content": "Count to 3."} + ] + + # First request with temp 0 + data1 = { + "model": MODEL_PATH, + "messages": messages, + "max_tokens": 10, + "temperature": 0.0, + "chat_template_kwargs": {"enable_thinking": False} + } + req1 = urllib.request.Request( + f"http://{SERVER_HOST}:{SERVER_PORT}/v1/chat/completions", + data=json.dumps(data1).encode(), + headers={"Content-Type": "application/json"}, + method="POST" + ) + response1 = json.loads(urllib.request.urlopen(req1, timeout=REQUEST_TIMEOUT).read().decode()) + if "error" in response1: + print(f"FAILED (request 1): {response1['error']}") + return False + cached1 = response1.get("usage", {}).get("prompt_tokens_details", {}).get("cached_tokens", 0) + print(f"Request 1 (temp=0.0): cached={cached1}") + + # Second request with temp 1.0 (different generation param) + data2 = { + "model": MODEL_PATH, + "messages": messages, + "max_tokens": 10, + "temperature": 1.0, + "chat_template_kwargs": {"enable_thinking": False} + } + req2 = urllib.request.Request( + f"http://{SERVER_HOST}:{SERVER_PORT}/v1/chat/completions", + data=json.dumps(data2).encode(), + headers={"Content-Type": "application/json"}, + method="POST" + ) + response2 = json.loads(urllib.request.urlopen(req2, timeout=REQUEST_TIMEOUT).read().decode()) + if "error" in response2: + print(f"FAILED (request 2): {response2['error']}") + return False + cached2 = response2.get("usage", {}).get("prompt_tokens_details", {}).get("cached_tokens", 0) + prompt_tokens2 = response2.get("usage", {}).get("prompt_tokens", 0) + print(f"Request 2 (temp=1.0): cached={cached2}, prompt_tokens={prompt_tokens2}") + + # Second request should get exact match (all tokens cached) despite different temperature + assert cached1 == 0, "First request should have 0 cached tokens" + assert cached2 == prompt_tokens2, f"Second request should have all tokens cached (exact match), got {cached2}/{prompt_tokens2}" + + print("PASSED") + return True + + +def test_partial_cache_extension(): + """Test that partial cache matches extend correctly.""" + print("\n=== Test: Partial Cache Extension ===") + + import time + unique_id = str(int(time.time() * 1000)) + + # First, establish a base conversation + messages1 = [ + {"role": "system", "content": f"You are a helpful assistant. [{unique_id}]"}, + {"role": "user", "content": "My name is Test."} + ] + response1 = make_request(messages1, enable_thinking=False) + if "error" in response1: + print(f"FAILED (request 1): {response1['error']}") + return False + content1 = response1.get("choices", [{}])[0].get("message", {}).get("content", "") + print(f"Request 1: Base conversation established") + + # Extend the conversation + messages2 = messages1 + [ + {"role": "assistant", "content": content1[:30]}, + {"role": "user", "content": "What is my name?"} + ] + response2 = make_request(messages2, enable_thinking=False) + if "error" in response2: + print(f"FAILED (request 2): {response2['error']}") + return False + cached2 = response2.get("usage", {}).get("prompt_tokens_details", {}).get("cached_tokens", 0) + content2 = response2.get("choices", [{}])[0].get("message", {}).get("content", "") + print(f"Request 2: cached={cached2}") + + # Extend further + messages3 = messages2 + [ + {"role": "assistant", "content": content2[:30]}, + {"role": "user", "content": "Can you spell it?"} + ] + response3 = make_request(messages3, enable_thinking=False) + if "error" in response3: + print(f"FAILED (request 3): {response3['error']}") + return False + cached3 = response3.get("usage", {}).get("prompt_tokens_details", {}).get("cached_tokens", 0) + print(f"Request 3: cached={cached3}") + + # Verify cache increases with each extension + assert cached2 > 0, f"Request 2 should use cache, got {cached2}" + assert cached3 > cached2, f"Request 3 ({cached3}) should have more cache than request 2 ({cached2})" + + print("PASSED") + return True + + +def run_tests(): + """Run all tests.""" + print("=" * 60) + print("Hybrid Model Caching Tests") + print("=" * 60) + + # Start server + print("\nStarting server...") + server_proc = start_server() + print("Server started successfully!") + + try: + results = {} + + # Run tests + results["basic_generation"] = test_basic_generation() + results["prompt_chaining"] = test_prompt_chaining() + results["extended_prompt_chaining"] = test_extended_prompt_chaining() + results["cache_persistence"] = test_cache_persistence() + results["cache_invalidation"] = test_cache_invalidation() + results["conversation_branching"] = test_conversation_branching() + results["cache_survives_multiple_branches"] = test_cache_survives_multiple_branches() + results["exact_match_reuse"] = test_exact_match_reuse() + results["thinking_mode_caching"] = test_thinking_mode_caching() + results["interleaved_conversations"] = test_interleaved_conversations() + results["long_conversation_chain"] = test_long_conversation_chain() + results["cache_with_different_params"] = test_cache_with_different_params() + results["partial_cache_extension"] = test_partial_cache_extension() + + # Summary + print("\n" + "=" * 60) + print("Test Results Summary") + print("=" * 60) + + all_passed = True + for test_name, passed in results.items(): + status = "PASSED" if passed else "FAILED" + print(f" {test_name}: {status}") + if not passed: + all_passed = False + + print() + if all_passed: + print("All tests PASSED!") + return 0 + else: + print("Some tests FAILED!") + return 1 + + finally: + print("\nStopping server...") + stop_server(server_proc) + + +if __name__ == "__main__": + sys.exit(run_tests())