From 51b6af3b203d3a40a9ed71cc82f4e4c25cff0979 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Wed, 31 Jul 2024 13:54:15 -0400 Subject: [PATCH 1/4] Use pydantic for cache entries and hits --- docs/user_guide/llmcache_03.ipynb | 35 +-- redisvl/extensions/llmcache/base.py | 1 - redisvl/extensions/llmcache/schema.py | 106 +++++++++ redisvl/extensions/llmcache/semantic.py | 294 ++++++++++-------------- redisvl/index/index.py | 15 +- tests/integration/test_llmcache.py | 107 ++++++--- tests/unit/test_llmcache_schema.py | 119 ++++++++++ 7 files changed, 453 insertions(+), 224 deletions(-) create mode 100644 redisvl/extensions/llmcache/schema.py create mode 100644 tests/unit/test_llmcache_schema.py diff --git a/docs/user_guide/llmcache_03.ipynb b/docs/user_guide/llmcache_03.ipynb index 8d326f8f..3403b287 100644 --- a/docs/user_guide/llmcache_03.ipynb +++ b/docs/user_guide/llmcache_03.ipynb @@ -83,7 +83,6 @@ "\n", "llmcache = SemanticCache(\n", " name=\"llmcache\", # underlying search index name\n", - " prefix=\"llmcache\", # redis key prefix for hash entries\n", " redis_url=\"redis://localhost:6379\", # redis connection url string\n", " distance_threshold=0.1 # semantic cache distance threshold\n", ")" @@ -107,13 +106,15 @@ "│ llmcache │ HASH │ ['llmcache'] │ [] │ 0 │\n", "╰──────────────┴────────────────┴──────────────┴─────────────────┴────────────╯\n", "Index Fields:\n", - "╭───────────────┬───────────────┬────────┬────────────────┬────────────────╮\n", - "│ Name │ Attribute │ Type │ Field Option │ Option Value │\n", - "├───────────────┼───────────────┼────────┼────────────────┼────────────────┤\n", - "│ prompt │ prompt │ TEXT │ WEIGHT │ 1 │\n", - "│ response │ response │ TEXT │ WEIGHT │ 1 │\n", - "│ prompt_vector │ prompt_vector │ VECTOR │ │ │\n", - "╰───────────────┴───────────────┴────────┴────────────────┴────────────────╯\n" + "╭───────────────┬───────────────┬─────────┬────────────────┬────────────────┬────────────────┬────────────────┬────────────────┬────────────────┬─────────────────┬────────────────╮\n", + "│ Name │ Attribute │ Type │ Field Option │ Option Value │ Field Option │ Option Value │ Field Option │ Option Value │ Field Option │ Option Value │\n", + "├───────────────┼───────────────┼─────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼─────────────────┼────────────────┤\n", + "│ prompt │ prompt │ TEXT │ WEIGHT │ 1 │ │ │ │ │ │ │\n", + "│ response │ response │ TEXT │ WEIGHT │ 1 │ │ │ │ │ │ │\n", + "│ inserted_at │ inserted_at │ NUMERIC │ │ │ │ │ │ │ │ │\n", + "│ updated_at │ updated_at │ NUMERIC │ │ │ │ │ │ │ │ │\n", + "│ prompt_vector │ prompt_vector │ VECTOR │ algorithm │ FLAT │ data_type │ FLOAT32 │ dim │ 768 │ distance_metric │ COSINE │\n", + "╰───────────────┴───────────────┴─────────┴────────────────┴────────────────┴────────────────┴────────────────┴────────────────┴────────────────┴─────────────────┴────────────────╯\n" ] } ], @@ -208,7 +209,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "[{'id': 'llmcache:115049a298532be2f181edb03f766770c0db84c22aff39003fec340deaec7545', 'vector_distance': '9.53674316406e-07', 'prompt': 'What is the capital of France?', 'response': 'Paris', 'metadata': {'city': 'Paris', 'country': 'france'}}]\n" + "[{'prompt': 'What is the capital of France?', 'response': 'Paris', 'metadata': {'city': 'Paris', 'country': 'france'}, 'key': 'llmcache:115049a298532be2f181edb03f766770c0db84c22aff39003fec340deaec7545'}]\n" ] } ], @@ -384,7 +385,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ @@ -408,14 +409,14 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Without caching, a call to openAI to answer this simple question took 1.460299015045166 seconds.\n" + "Without caching, a call to openAI to answer this simple question took 0.9312698841094971 seconds.\n" ] }, { @@ -424,7 +425,7 @@ "'llmcache:67e0f6e28fe2a61c0022fd42bf734bb8ffe49d3e375fd69d692574295a20fc1a'" ] }, - "execution_count": 18, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } @@ -451,8 +452,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "Avg time taken with LLM cache enabled: 0.2560166358947754\n", - "Percentage of time saved: 82.47%\n" + "Avg time taken with LLM cache enabled: 0.4896167993545532\n", + "Percentage of time saved: 47.42%\n" ] } ], @@ -515,7 +516,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -540,7 +541,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.12" + "version": "3.10.14" }, "orig_nbformat": 4 }, diff --git a/redisvl/extensions/llmcache/base.py b/redisvl/extensions/llmcache/base.py index a1c88466..d11a404f 100644 --- a/redisvl/extensions/llmcache/base.py +++ b/redisvl/extensions/llmcache/base.py @@ -1,4 +1,3 @@ -import json from typing import Any, Dict, List, Optional from redisvl.redis.utils import hashify diff --git a/redisvl/extensions/llmcache/schema.py b/redisvl/extensions/llmcache/schema.py new file mode 100644 index 00000000..64ddeb21 --- /dev/null +++ b/redisvl/extensions/llmcache/schema.py @@ -0,0 +1,106 @@ +from typing import Any, Dict, List, Optional +from pydantic.v1 import BaseModel, Field, root_validator, validator +from redisvl.redis.utils import array_to_buffer, hashify +from redisvl.utils.utils import current_timestamp, deserialize, serialize +from redisvl.schema import IndexSchema + + +class CacheEntry(BaseModel): + entry_id: str + prompt: str + response: str + prompt_vector: List[float] + inserted_at: float = Field(default_factory=current_timestamp) + updated_at: float = Field(default_factory=current_timestamp) + metadata: Optional[Dict[str, Any]] = Field(default=None) + filters: Optional[Dict[str, Any]] = Field(default=None) + + @root_validator(pre=True) + @classmethod + def generate_id(cls, values): + # Ensure entry_id is set + if not values.get("entry_id"): + values["entry_id"] = hashify(values["prompt"]) + return values + + @validator("metadata") + def non_empty_metadata(cls, v): + if v is not None and not isinstance(v, dict): + raise TypeError("Metadata must be a dictionary.") + return v + + def to_dict(self) -> Dict: + data = self.dict(exclude_none=True) + data["prompt_vector"] = array_to_buffer(self.prompt_vector) + if self.metadata: + data["metadata"] = serialize(self.metadata) + if self.filters: + data.update(self.filters) + del data["filters"] + return data + + +class CacheHit(BaseModel): + entry_id: str + prompt: str + response: str + vector_distance: float + inserted_at: float + updated_at: float + metadata: Optional[Dict[str, Any]] = Field(default=None) + filters: Optional[Dict[str, Any]] = Field(default=None) + + @root_validator(pre=True) + @classmethod + def validate_cache_hit(cls, values): + # Deserialize metadata if necessary + if "metadata" in values and isinstance(values["metadata"], str): + values["metadata"] = deserialize(values["metadata"]) + + # Separate filters from other fields + known_fields = set(cls.__fields__.keys()) + filters = {k: v for k, v in values.items() if k not in known_fields} + + # Add filters to values + if filters: + values["filters"] = filters + + # Remove filter fields from the main values + for k in filters: + values.pop(k) + + return values + + def to_dict(self) -> Dict: + data = self.dict(exclude_none=True) + if self.filters: + data.update(self.filters) + del data["filters"] + + return data + + +class SemanticCacheIndexSchema(IndexSchema): + + @classmethod + def from_params(cls, name: str, prefix: str, vector_dims: int): + + return cls( + index={"name": name, "prefix": prefix}, # type: ignore + fields=[ # type: ignore + {"name": "prompt", "type": "text"}, + {"name": "response", "type": "text"}, + {"name": "inserted_at", "type": "numeric"}, + {"name": "updated_at", "type": "numeric"}, + { + "name": "prompt_vector", + "type": "vector", + "attrs": { + "dims": vector_dims, + "datatype": "float32", + "distance_metric": "cosine", + "algorithm": "flat", + }, + }, + ], + ) \ No newline at end of file diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index b17c18c9..f0b5b66b 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -3,61 +3,38 @@ from redis import Redis from redisvl.extensions.llmcache.base import BaseLLMCache +from redisvl.extensions.llmcache.schema import CacheEntry, CacheHit, SemanticCacheIndexSchema from redisvl.index import SearchIndex from redisvl.query import RangeQuery from redisvl.query.filter import FilterExpression, Tag from redisvl.redis.utils import array_to_buffer -from redisvl.schema import IndexSchema -from redisvl.utils.utils import current_timestamp, deserialize, serialize +from redisvl.utils.utils import ( + current_timestamp, + deserialize, + serialize, + validate_vector_dims, +) from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer -class SemanticCacheIndexSchema(IndexSchema): - - @classmethod - def from_params(cls, name: str, vector_dims: int): - - return cls( - index={"name": name, "prefix": name}, # type: ignore - fields=[ # type: ignore - {"name": "prompt", "type": "text"}, - {"name": "response", "type": "text"}, - {"name": "inserted_at", "type": "numeric"}, - {"name": "updated_at", "type": "numeric"}, - {"name": "label", "type": "tag"}, - { - "name": "prompt_vector", - "type": "vector", - "attrs": { - "dims": vector_dims, - "datatype": "float32", - "distance_metric": "cosine", - "algorithm": "flat", - }, - }, - ], - ) - - class SemanticCache(BaseLLMCache): """Semantic Cache for Large Language Models.""" - entry_id_field_name: str = "_id" + entry_id_field_name: str = "entry_id" prompt_field_name: str = "prompt" + response_field_name: str = "response" vector_field_name: str = "prompt_vector" inserted_at_field_name: str = "inserted_at" updated_at_field_name: str = "updated_at" - tag_field_name: str = "label" - response_field_name: str = "response" metadata_field_name: str = "metadata" def __init__( self, name: str = "llmcache", - prefix: Optional[str] = None, distance_threshold: float = 0.1, ttl: Optional[int] = None, vectorizer: Optional[BaseVectorizer] = None, + filterable_fields: Optional[List[Dict[str, Any]]] = None, redis_client: Optional[Redis] = None, redis_url: str = "redis://localhost:6379", connection_kwargs: Dict[str, Any] = {}, @@ -68,9 +45,6 @@ def __init__( Args: name (str, optional): The name of the semantic cache search index. Defaults to "llmcache". - prefix (Optional[str], optional): The prefix for Redis keys - associated with the semantic cache search index. Defaults to - None, and the index name will be used as the key prefix. distance_threshold (float, optional): Semantic threshold for the cache. Defaults to 0.1. ttl (Optional[int], optional): The time-to-live for records cached @@ -92,7 +66,9 @@ def __init__( super().__init__(ttl) # Use the index name as the key prefix by default - if prefix is None: + if "prefix" in kwargs: + prefix = kwargs["prefix"] + else: prefix = name # Set vectorizer default @@ -101,25 +77,37 @@ def __init__( model="sentence-transformers/all-mpnet-base-v2" ) - schema = SemanticCacheIndexSchema.from_params(name, vectorizer.dims) + # Create semantic cache schema + schema = SemanticCacheIndexSchema.from_params(name, prefix, vectorizer.dims) + + # Process fields + self.return_fields = [ + self.entry_id_field_name, + self.prompt_field_name, + self.response_field_name, + self.inserted_at_field_name, + self.updated_at_field_name, + self.metadata_field_name, + ] + + if filterable_fields is not None: + for filter_field in filterable_fields: + if filter_field["name"] in self.return_fields or filter_field["name"] =="key": + raise ValueError(f'{filter_field["name"]} is a reserved field name for the semantic cache schema') + schema.add_field(filter_field) + # Add to return fields too + self.return_fields.append(filter_field["name"]) + self._index = SearchIndex(schema=schema) - # handle redis connection + # Handle redis connection if redis_client: self._index.set_client(redis_client) elif redis_url: self._index.connect(redis_url=redis_url, **connection_kwargs) - # initialize other components - self.default_return_fields = [ - self.entry_id_field_name, - self.prompt_field_name, - self.response_field_name, - self.tag_field_name, - self.vector_field_name, - self.metadata_field_name, - ] - self.set_vectorizer(vectorizer) + # Initialize other components + self._set_vectorizer(vectorizer) self.set_threshold(distance_threshold) self._index.create(overwrite=False) @@ -157,7 +145,7 @@ def set_threshold(self, distance_threshold: float) -> None: ) self._distance_threshold = float(distance_threshold) - def set_vectorizer(self, vectorizer: BaseVectorizer) -> None: + def _set_vectorizer(self, vectorizer: BaseVectorizer) -> None: """Sets the vectorizer for the LLM cache. Must be a valid subclass of BaseVectorizer and have equivalent @@ -175,14 +163,7 @@ def set_vectorizer(self, vectorizer: BaseVectorizer) -> None: raise TypeError("Must provide a valid redisvl.vectorizer class.") schema_vector_dims = self._index.schema.fields[self.vector_field_name].attrs.dims # type: ignore - - if schema_vector_dims != vectorizer.dims: - raise ValueError( - "Invalid vector dimensions! " - f"Vectorizer has dims defined as {vectorizer.dims}", - f"Vector field has dims defined as {schema_vector_dims}", - ) - + validate_vector_dims(vectorizer.dims, schema_vector_dims) self._vectorizer = vectorizer def clear(self) -> None: @@ -194,13 +175,20 @@ def delete(self) -> None: index.""" self._index.delete(drop=True) - def drop(self, document_ids: Union[str, List[str]]) -> None: - """Remove a specific entry or entries from the cache by it's ID. + def drop( + self, ids: Optional[List[str]] = None, keys: Optional[List[str]] = None + ) -> None: + """Manually expire specific entries from the cache by id or specific + Redis key. Args: - document_ids (Union[str, List[str]]): The document ID or IDs to remove from the cache. + ids (Optional[str]): The document ID or IDs to remove from the cache. + keys (Optional[str]): """ - self._index.drop_keys(document_ids) + if ids is not None: + self._index.drop_keys([self._index.key(id) for id in ids]) + if keys is not None: + self._index.drop_keys(keys) def _refresh_ttl(self, key: str) -> None: """Refresh the time-to-live for the specified key.""" @@ -212,61 +200,14 @@ def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]: configured vectorizer.""" if not isinstance(prompt, str): raise TypeError("Prompt must be a string.") - return self._vectorizer.embed(prompt) - - def _search_cache( - self, - vector: List[float], - num_results: int, - return_fields: Optional[List[str]], - tag_filter: Optional[FilterExpression], - ) -> List[Dict[str, Any]]: - """Searches the semantic cache for similar prompt vectors and returns - the specified return fields for each cache hit.""" - # Setup and type checks - if not isinstance(vector, list): - raise TypeError("Vector must be a list of floats") - return_fields = return_fields or self.default_return_fields - - if not isinstance(return_fields, list): - raise TypeError("return_fields must be a list of field names") - - # Construct vector RangeQuery for the cache check - query = RangeQuery( - vector=vector, - vector_field_name=self.vector_field_name, - return_fields=return_fields, - distance_threshold=self._distance_threshold, - num_results=num_results, - return_score=True, - ) - if tag_filter: - query.set_filter(tag_filter) # type: ignore - - # Gather and return the cache hits - cache_hits: List[Dict[str, Any]] = self._index.query(query) - # Process cache hits - for hit in cache_hits: - key = hit["id"] - self._refresh_ttl(key) - # Check for metadata and deserialize - if self.metadata_field_name in hit: - hit[self.metadata_field_name] = deserialize( - hit[self.metadata_field_name] - ) - return cache_hits + return self._vectorizer.embed(prompt) def _check_vector_dims(self, vector: List[float]): """Checks the size of the provided vector and raises an error if it doesn't match the search index vector dimensions.""" schema_vector_dims = self._index.schema.fields[self.vector_field_name].attrs.dims # type: ignore - if schema_vector_dims != len(vector): - raise ValueError( - "Invalid vector dimensions! " - f"Vector has dims defined as {len(vector)}", - f"Vector field has dims defined as {schema_vector_dims}", - ) + validate_vector_dims(len(vector), schema_vector_dims) def check( self, @@ -274,7 +215,7 @@ def check( vector: Optional[List[float]] = None, num_results: int = 1, return_fields: Optional[List[str]] = None, - tag_filter: Optional[FilterExpression] = None, + filter_expression: Optional[FilterExpression] = None, ) -> List[Dict[str, Any]]: """Checks the semantic cache for results similar to the specified prompt or vector. @@ -294,8 +235,9 @@ def check( return_fields (Optional[List[str]], optional): The fields to include in each returned result. If None, defaults to all available fields in the cached entry. - tag_filter (Optional[FilterExpression]) : the tag filter to filter - results by. Default is None and full cache is searched. + filter_expression (Optional[FilterExpression]) : Optional filter expression + that can be used to filter cache results. Defaults to None and + the full cache will be searched. Returns: List[Dict[str, Any]]: A list of dicts containing the requested @@ -315,12 +257,38 @@ def check( if not (prompt or vector): raise ValueError("Either prompt or vector must be specified.") - # Use provided vector or create from prompt vector = vector or self._vectorize_prompt(prompt) self._check_vector_dims(vector) + return_fields = return_fields or self.return_fields + + if not isinstance(return_fields, list): + raise TypeError("return_fields must be a list of field names") + + query = RangeQuery( + vector=vector, + vector_field_name=self.vector_field_name, + return_fields=self.return_fields, + distance_threshold=self._distance_threshold, + num_results=num_results, + return_score=True, + filter_expression=filter_expression, + ) + + cache_hits: List[Dict[Any, str]] = [] + + # Search the cache! + cache_search_results = self._index.query(query) + + for cache_search_result in cache_search_results: + key = cache_search_result["id"] + self._refresh_ttl(key) + + # Create cache hit + cache_hit = CacheHit(**cache_search_result) + cache_hit_dict = {k: v for k, v in cache_hit.to_dict().items() if k in return_fields} + cache_hit_dict["key"] = key + cache_hits.append(cache_hit_dict) - # Check for cache hits by searching the cache - cache_hits = self._search_cache(vector, num_results, return_fields, tag_filter) return cache_hits def store( @@ -328,8 +296,8 @@ def store( prompt: str, response: str, vector: Optional[List[float]] = None, - metadata: Optional[dict] = None, - tag: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + filters: Optional[Dict[str, Any]] = None, ) -> str: """Stores the specified key-value pair in the cache along with metadata. @@ -339,9 +307,9 @@ def store( vector (Optional[List[float]], optional): The prompt vector to cache. Defaults to None, and the prompt vector is generated on demand. - metadata (Optional[dict], optional): The optional metadata to cache + metadata (Optional[Dict[str, Any]], optional): The optional metadata to cache alongside the prompt and response. Defaults to None. - tag (Optional[str]): The optional tag to assign to the cache entry. + filters (Optional[Dict[str, Any]]): The optional tag to assign to the cache entry. Defaults to None. Returns: @@ -362,29 +330,24 @@ def store( """ # Vectorize prompt if necessary and create cache payload vector = vector or self._vectorize_prompt(prompt) + self._check_vector_dims(vector) - # Construct semantic cache payload - now = current_timestamp() - id_field = self.entry_id_field_name - payload = { - id_field: self.hash_input(prompt), - self.prompt_field_name: prompt, - self.response_field_name: response, - self.vector_field_name: array_to_buffer(vector), - self.inserted_at_field_name: now, - self.updated_at_field_name: now, - } - if metadata is not None: - if not isinstance(metadata, dict): - raise TypeError("If specified, cached metadata must be a dictionary.") - # Serialize the metadata dict and add to cache payload - payload[self.metadata_field_name] = serialize(metadata) - if tag is not None: - payload[self.tag_field_name] = tag - - # Load LLMCache entry with TTL - keys = self._index.load(data=[payload], ttl=self._ttl, id_field=id_field) + # Build cache entry for the cache + cache_entry = CacheEntry( + prompt=prompt, + response=response, + prompt_vector=vector, + metadata=metadata, + filters=filters, + ) + + # Load cache entry with TTL + keys = self._index.load( + data=[cache_entry.to_dict()], + ttl=self._ttl, + id_field=self.entry_id_field_name, + ) return keys[0] def update(self, key: str, **kwargs) -> None: @@ -392,8 +355,7 @@ def update(self, key: str, **kwargs) -> None: are passed, then only the document TTL is refreshed. Args: - key (str): the key of the document to update. - kwargs: + key (str): the key of the document to update using kwargs. Raises: ValueError if an incorrect mapping is provided as a kwarg. @@ -404,28 +366,24 @@ def update(self, key: str, **kwargs) -> None: cache.update(key, metadata={"hit_count": 1, "model_name": "Llama-2-7b"}) ) """ - if not kwargs: - self._refresh_ttl(key) - return - - for _key, val in kwargs.items(): - if _key not in { - self.prompt_field_name, - self.vector_field_name, - self.response_field_name, - self.tag_field_name, - self.metadata_field_name, - }: - raise ValueError(f" {key} is not a valid field within document") - - # Check for metadata and deserialize - if _key == self.metadata_field_name: - if isinstance(val, dict): - kwargs[_key] = serialize(val) - else: - raise TypeError( - "If specified, cached metadata must be a dictionary." - ) - kwargs.update({self.updated_at_field_name: current_timestamp()}) - self._index.client.hset(key, mapping=kwargs) # type: ignore + if kwargs: + for k, v in kwargs.items(): + + # Make sure the item is in the index schema + if k not in set(self._index.schema.field_names + [self.metadata_field_name]): + raise ValueError(f"{k} is not a valid field within the cache entry") + + # Check for metadata and deserialize + if k == self.metadata_field_name: + if isinstance(v, dict): + kwargs[k] = serialize(v) + else: + raise TypeError( + "If specified, cached metadata must be a dictionary." + ) + + kwargs.update({self.updated_at_field_name: current_timestamp()}) + + self._index.client.hset(key, mapping=kwargs) # type: ignore + self._refresh_ttl(key) diff --git a/redisvl/index/index.py b/redisvl/index/index.py index 6a01c1ce..f5e6b4a6 100644 --- a/redisvl/index/index.py +++ b/redisvl/index/index.py @@ -193,11 +193,6 @@ def storage_type(self) -> StorageType: hash or json.""" return self.schema.index.storage_type - @property - def client(self) -> Optional[Union[redis.Redis, aredis.Redis]]: - """The underlying redis-py client object.""" - return self._redis_client - @classmethod def from_yaml(cls, schema_path: str, **kwargs): """Create a SearchIndex from a YAML schema file. @@ -364,6 +359,11 @@ def from_existing( schema = IndexSchema.from_dict(schema_dict) return cls(schema, redis_client, **kwargs) + @property + def client(self) -> Optional[redis.Redis]: + """The underlying redis-py client object.""" + return self._redis_client + def connect(self, redis_url: Optional[str] = None, **kwargs): """Connect to a Redis instance using the provided `redis_url`, falling back to the `REDIS_URL` environment variable (if available). @@ -843,6 +843,11 @@ async def from_existing( await index.set_client(redis_client) return index + @property + def client(self) -> Optional[aredis.Redis]: + """The underlying redis-py client object.""" + return self._redis_client + async def connect(self, redis_url: Optional[str] = None, **kwargs): """Connect to a Redis instance using the provided `redis_url`, falling back to the `REDIS_URL` environment variable (if available). diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index b272ac30..ec1c7c15 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -1,7 +1,8 @@ from collections import namedtuple from time import sleep, time - +from pydantic.v1 import ValidationError import pytest + from redis.exceptions import ConnectionError from redisvl.extensions.llmcache import SemanticCache @@ -23,6 +24,17 @@ def cache(vectorizer, redis_url): yield cache_instance cache_instance._index.delete(True) # Clean up index +@pytest.fixture +def cache_with_filters(vectorizer, redis_url): + cache_instance = SemanticCache( + vectorizer=vectorizer, + distance_threshold=0.2, + filterable_fields=[{"name": "label", "type": "tag"}], + redis_url=redis_url + ) + yield cache_instance + cache_instance._index.delete(True) # Clean up index + @pytest.fixture def cache_no_cleanup(vectorizer, redis_url): @@ -100,32 +112,30 @@ def test_return_fields(cache, vectorizer): # check default return fields check_result = cache.check(vector=vector) assert set(check_result[0].keys()) == { - "id", - "_id", + "key", + "entry_id", "prompt", "response", - "prompt_vector", "vector_distance", + "inserted_at", + "updated_at", } - # check all return fields + # check specific return fields fields = [ - "id", - "_id", + "key", + "entry_id", "prompt", "response", - "inserted_at", - "updated_at", - "prompt_vector", "vector_distance", ] - check_result = cache.check(vector=vector, return_fields=fields[:]) + check_result = cache.check(vector=vector, return_fields=fields) assert set(check_result[0].keys()) == set(fields) # check only some return fields fields = ["inserted_at", "updated_at"] - check_result = cache.check(vector=vector, return_fields=fields[:]) - fields.extend(["id", "vector_distance"]) # id and vector_distance always returned + check_result = cache.check(vector=vector, return_fields=fields) + fields.append("key") assert set(check_result[0].keys()) == set(fields) @@ -178,7 +188,7 @@ def test_drop_document(cache, vectorizer): cache.store(prompt, response, vector=vector) check_result = cache.check(vector=vector) - cache.drop(check_result[0]["id"]) + cache.drop(ids=[check_result[0]["entry_id"]]) recheck_result = cache.check(vector=vector) assert len(recheck_result) == 0 @@ -200,8 +210,9 @@ def test_drop_documents(cache, vectorizer): cache.store(prompt, response, vector=vector) check_result = cache.check(vector=vector, num_results=3) - keys = [r["id"] for r in check_result[0:2]] # drop first 2 entries - cache.drop(keys) + print(check_result, flush=True) + ids = [r["entry_id"] for r in check_result[0:2]] # drop first 2 entries + cache.drop(ids=ids) recheck_result = cache.check(vector=vector, num_results=3) assert len(recheck_result) == 1 @@ -214,7 +225,7 @@ def test_updating_document(cache): cache.store(prompt=prompt, response=response) check_result = cache.check(prompt=prompt, return_fields=["updated_at"]) - key = check_result[0]["id"] + key = check_result[0]["key"] sleep(1) @@ -290,9 +301,7 @@ def test_store_with_invalid_metadata(cache, vectorizer): vector = vectorizer.embed(prompt) - with pytest.raises( - TypeError, match=r"If specified, cached metadata must be a dictionary." - ): + with pytest.raises(ValidationError): cache.store(prompt, response, vector=vector, metadata=metadata) @@ -381,8 +390,11 @@ def test_vector_size(cache, vectorizer): cache.check(vector=[1, 2, 3]) -# test we can pass a list of tags and we'll include all results that match -def test_multiple_tags(cache): +def test_cache_with_filters(cache_with_filters): + assert "label" in cache_with_filters._index.schema.fields + + +def test_cache_filtering(cache_with_filters): tag_1 = "group 0" tag_2 = "group 1" tag_3 = "group 2" @@ -396,43 +408,72 @@ def test_multiple_tags(cache): for i in range(4): prompt = f"test prompt {i}" response = f"test response {i}" - cache.store(prompt, response, tag=tags[i]) + cache_with_filters.store(prompt, response, filters={"label": tags[i]}) # test we can specify one specific tag - results = cache.check("test prompt 1", tag_filter=filter_1, num_results=5) + results = cache_with_filters.check("test prompt 1", filter_expression=filter_1, num_results=5) assert len(results) == 1 assert results[0]["prompt"] == "test prompt 0" # test we can pass a list of tags combined_filter = filter_1 | filter_2 | filter_3 - results = cache.check("test prompt 1", tag_filter=combined_filter, num_results=5) + results = cache_with_filters.check("test prompt 1", filter_expression=combined_filter, num_results=5) assert len(results) == 3 # test that default tag param searches full cache - results = cache.check("test prompt 1", num_results=5) + results = cache_with_filters.check("test prompt 1", num_results=5) assert len(results) == 4 # test no results are returned if we pass a nonexistant tag bad_filter = Tag("label") == "bad tag" - results = cache.check("test prompt 1", tag_filter=bad_filter, num_results=5) + results = cache_with_filters.check("test prompt 1", filter_expression=bad_filter, num_results=5) assert len(results) == 0 -def test_complex_filters(cache): - cache.store(prompt="prompt 1", response="response 1") - cache.store(prompt="prompt 2", response="response 2") +def test_cache_bad_filters(vectorizer, redis_url): + with pytest.raises(ValueError): + cache_instance = SemanticCache( + vectorizer=vectorizer, + distance_threshold=0.2, + # invalid field type + filterable_fields=[{"name": "label", "type": "tag"}, {"name": "test", "type": "nothing"}], + redis_url=redis_url + ) + + with pytest.raises(ValueError): + cache_instance = SemanticCache( + vectorizer=vectorizer, + distance_threshold=0.2, + # duplicate field type + filterable_fields=[{"name": "label", "type": "tag"}, {"name": "label", "type": "tag"}], + redis_url=redis_url + ) + + with pytest.raises(ValueError): + cache_instance = SemanticCache( + vectorizer=vectorizer, + distance_threshold=0.2, + # reserved field name + filterable_fields=[{"name": "label", "type": "tag"}, {"name": "metadata", "type": "tag"}], + redis_url=redis_url + ) + + +def test_complex_filters(cache_with_filters): + cache_with_filters.store(prompt="prompt 1", response="response 1") + cache_with_filters.store(prompt="prompt 2", response="response 2") sleep(1) current_timestamp = time() - cache.store(prompt="prompt 3", response="response 3") + cache_with_filters.store(prompt="prompt 3", response="response 3") # test we can do range filters on inserted_at and updated_at fields range_filter = Num("inserted_at") < current_timestamp - results = cache.check("prompt 1", tag_filter=range_filter, num_results=5) + results = cache_with_filters.check("prompt 1", filter_expression=range_filter, num_results=5) assert len(results) == 2 # test we can combine range filters and text filters prompt_filter = Text("prompt") % "*pt 1" combined_filter = prompt_filter & range_filter - results = cache.check("prompt 1", tag_filter=combined_filter, num_results=5) + results = cache_with_filters.check("prompt 1", filter_expression=combined_filter, num_results=5) assert len(results) == 1 diff --git a/tests/unit/test_llmcache_schema.py b/tests/unit/test_llmcache_schema.py new file mode 100644 index 00000000..11b8bb58 --- /dev/null +++ b/tests/unit/test_llmcache_schema.py @@ -0,0 +1,119 @@ +import pytest +import json + +from pydantic.v1 import ValidationError +from redisvl.extensions.llmcache.schema import CacheEntry, CacheHit +from redisvl.redis.utils import hashify, array_to_buffer + + +def test_valid_cache_entry_creation(): + entry = CacheEntry( + prompt="What is AI?", + response="AI is artificial intelligence.", + prompt_vector=[0.1, 0.2, 0.3] + ) + assert entry.entry_id == hashify("What is AI?") + assert entry.prompt == "What is AI?" + assert entry.response == "AI is artificial intelligence." + assert entry.prompt_vector == [0.1, 0.2, 0.3] + +def test_cache_entry_with_given_entry_id(): + entry = CacheEntry( + entry_id="custom_id", + prompt="What is AI?", + response="AI is artificial intelligence.", + prompt_vector=[0.1, 0.2, 0.3] + ) + assert entry.entry_id == "custom_id" + +def test_cache_entry_with_invalid_metadata(): + with pytest.raises(ValidationError): + CacheEntry( + prompt="What is AI?", + response="AI is artificial intelligence.", + prompt_vector=[0.1, 0.2, 0.3], + metadata="invalid_metadata" + ) + +def test_cache_entry_to_dict(): + entry = CacheEntry( + prompt="What is AI?", + response="AI is artificial intelligence.", + prompt_vector=[0.1, 0.2, 0.3], + metadata={"author": "John"}, + filters={"category": "technology"} + ) + result = entry.to_dict() + assert result["entry_id"] == hashify("What is AI?") + assert result["metadata"] == json.dumps({"author": "John"}) + assert result["prompt_vector"] == array_to_buffer([0.1, 0.2, 0.3]) + assert result["category"] == "technology" + assert "filters" not in result + +def test_valid_cache_hit_creation(): + hit = CacheHit( + entry_id="entry_1", + prompt="What is AI?", + response="AI is artificial intelligence.", + vector_distance=0.1, + inserted_at=1625819123.123, + updated_at=1625819123.123 + ) + assert hit.entry_id == "entry_1" + assert hit.prompt == "What is AI?" + assert hit.response == "AI is artificial intelligence." + assert hit.vector_distance == 0.1 + assert hit.inserted_at == hit.updated_at == 1625819123.123 + +def test_cache_hit_with_serialized_metadata(): + hit = CacheHit( + entry_id="entry_1", + prompt="What is AI?", + response="AI is artificial intelligence.", + vector_distance=0.1, + inserted_at=1625819123.123, + updated_at=1625819123.123, + metadata=json.dumps({"author": "John"}) + ) + assert hit.metadata == {"author": "John"} + +def test_cache_hit_to_dict(): + hit = CacheHit( + entry_id="entry_1", + prompt="What is AI?", + response="AI is artificial intelligence.", + vector_distance=0.1, + inserted_at=1625819123.123, + updated_at=1625819123.123, + filters={"category": "technology"} + ) + result = hit.to_dict() + assert result["entry_id"] == "entry_1" + assert result["prompt"] == "What is AI?" + assert result["response"] == "AI is artificial intelligence." + assert result["vector_distance"] == 0.1 + assert result["category"] == "technology" + assert "filters" not in result + +def test_cache_entry_with_empty_optional_fields(): + entry = CacheEntry( + prompt="What is AI?", + response="AI is artificial intelligence.", + prompt_vector=[0.1, 0.2, 0.3] + ) + result = entry.to_dict() + assert "metadata" not in result + assert "filters" not in result + +def test_cache_hit_with_empty_optional_fields(): + hit = CacheHit( + entry_id="entry_1", + prompt="What is AI?", + response="AI is artificial intelligence.", + vector_distance=0.1, + inserted_at=1625819123.123, + updated_at=1625819123.123 + ) + result = hit.to_dict() + assert "metadata" not in result + assert "filters" not in result From 704038b74c0152d13c8e961ef9dcf8ad76854525 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Thu, 1 Aug 2024 00:08:42 -0400 Subject: [PATCH 2/4] fix formatting and mypy --- redisvl/extensions/llmcache/schema.py | 8 +++-- redisvl/extensions/llmcache/semantic.py | 25 +++++++++---- tests/integration/test_llmcache.py | 48 +++++++++++++++++-------- tests/unit/test_llmcache_schema.py | 31 ++++++++++------ 4 files changed, 78 insertions(+), 34 deletions(-) diff --git a/redisvl/extensions/llmcache/schema.py b/redisvl/extensions/llmcache/schema.py index 64ddeb21..60dcdc94 100644 --- a/redisvl/extensions/llmcache/schema.py +++ b/redisvl/extensions/llmcache/schema.py @@ -1,12 +1,14 @@ from typing import Any, Dict, List, Optional + from pydantic.v1 import BaseModel, Field, root_validator, validator + from redisvl.redis.utils import array_to_buffer, hashify -from redisvl.utils.utils import current_timestamp, deserialize, serialize from redisvl.schema import IndexSchema +from redisvl.utils.utils import current_timestamp, deserialize, serialize class CacheEntry(BaseModel): - entry_id: str + entry_id: Optional[str] = Field(default=None) prompt: str response: str prompt_vector: List[float] @@ -103,4 +105,4 @@ def from_params(cls, name: str, prefix: str, vector_dims: int): }, }, ], - ) \ No newline at end of file + ) diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index f0b5b66b..b790e6d1 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -3,7 +3,11 @@ from redis import Redis from redisvl.extensions.llmcache.base import BaseLLMCache -from redisvl.extensions.llmcache.schema import CacheEntry, CacheHit, SemanticCacheIndexSchema +from redisvl.extensions.llmcache.schema import ( + CacheEntry, + CacheHit, + SemanticCacheIndexSchema, +) from redisvl.index import SearchIndex from redisvl.query import RangeQuery from redisvl.query.filter import FilterExpression, Tag @@ -92,8 +96,13 @@ def __init__( if filterable_fields is not None: for filter_field in filterable_fields: - if filter_field["name"] in self.return_fields or filter_field["name"] =="key": - raise ValueError(f'{filter_field["name"]} is a reserved field name for the semantic cache schema') + if ( + filter_field["name"] in self.return_fields + or filter_field["name"] == "key" + ): + raise ValueError( + f'{filter_field["name"]} is a reserved field name for the semantic cache schema' + ) schema.add_field(filter_field) # Add to return fields too self.return_fields.append(filter_field["name"]) @@ -285,7 +294,9 @@ def check( # Create cache hit cache_hit = CacheHit(**cache_search_result) - cache_hit_dict = {k: v for k, v in cache_hit.to_dict().items() if k in return_fields} + cache_hit_dict = { + k: v for k, v in cache_hit.to_dict().items() if k in return_fields + } cache_hit_dict["key"] = key cache_hits.append(cache_hit_dict) @@ -370,7 +381,9 @@ def update(self, key: str, **kwargs) -> None: for k, v in kwargs.items(): # Make sure the item is in the index schema - if k not in set(self._index.schema.field_names + [self.metadata_field_name]): + if k not in set( + self._index.schema.field_names + [self.metadata_field_name] + ): raise ValueError(f"{k} is not a valid field within the cache entry") # Check for metadata and deserialize @@ -384,6 +397,6 @@ def update(self, key: str, **kwargs) -> None: kwargs.update({self.updated_at_field_name: current_timestamp()}) - self._index.client.hset(key, mapping=kwargs) # type: ignore + self._index.client.hset(key, mapping=kwargs) # type: ignore self._refresh_ttl(key) diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index ec1c7c15..34c15113 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -1,8 +1,8 @@ from collections import namedtuple from time import sleep, time -from pydantic.v1 import ValidationError -import pytest +import pytest +from pydantic.v1 import ValidationError from redis.exceptions import ConnectionError from redisvl.extensions.llmcache import SemanticCache @@ -24,13 +24,14 @@ def cache(vectorizer, redis_url): yield cache_instance cache_instance._index.delete(True) # Clean up index + @pytest.fixture def cache_with_filters(vectorizer, redis_url): cache_instance = SemanticCache( vectorizer=vectorizer, distance_threshold=0.2, filterable_fields=[{"name": "label", "type": "tag"}], - redis_url=redis_url + redis_url=redis_url, ) yield cache_instance cache_instance._index.delete(True) # Clean up index @@ -411,13 +412,17 @@ def test_cache_filtering(cache_with_filters): cache_with_filters.store(prompt, response, filters={"label": tags[i]}) # test we can specify one specific tag - results = cache_with_filters.check("test prompt 1", filter_expression=filter_1, num_results=5) + results = cache_with_filters.check( + "test prompt 1", filter_expression=filter_1, num_results=5 + ) assert len(results) == 1 assert results[0]["prompt"] == "test prompt 0" # test we can pass a list of tags combined_filter = filter_1 | filter_2 | filter_3 - results = cache_with_filters.check("test prompt 1", filter_expression=combined_filter, num_results=5) + results = cache_with_filters.check( + "test prompt 1", filter_expression=combined_filter, num_results=5 + ) assert len(results) == 3 # test that default tag param searches full cache @@ -426,7 +431,9 @@ def test_cache_filtering(cache_with_filters): # test no results are returned if we pass a nonexistant tag bad_filter = Tag("label") == "bad tag" - results = cache_with_filters.check("test prompt 1", filter_expression=bad_filter, num_results=5) + results = cache_with_filters.check( + "test prompt 1", filter_expression=bad_filter, num_results=5 + ) assert len(results) == 0 @@ -436,8 +443,11 @@ def test_cache_bad_filters(vectorizer, redis_url): vectorizer=vectorizer, distance_threshold=0.2, # invalid field type - filterable_fields=[{"name": "label", "type": "tag"}, {"name": "test", "type": "nothing"}], - redis_url=redis_url + filterable_fields=[ + {"name": "label", "type": "tag"}, + {"name": "test", "type": "nothing"}, + ], + redis_url=redis_url, ) with pytest.raises(ValueError): @@ -445,8 +455,11 @@ def test_cache_bad_filters(vectorizer, redis_url): vectorizer=vectorizer, distance_threshold=0.2, # duplicate field type - filterable_fields=[{"name": "label", "type": "tag"}, {"name": "label", "type": "tag"}], - redis_url=redis_url + filterable_fields=[ + {"name": "label", "type": "tag"}, + {"name": "label", "type": "tag"}, + ], + redis_url=redis_url, ) with pytest.raises(ValueError): @@ -454,8 +467,11 @@ def test_cache_bad_filters(vectorizer, redis_url): vectorizer=vectorizer, distance_threshold=0.2, # reserved field name - filterable_fields=[{"name": "label", "type": "tag"}, {"name": "metadata", "type": "tag"}], - redis_url=redis_url + filterable_fields=[ + {"name": "label", "type": "tag"}, + {"name": "metadata", "type": "tag"}, + ], + redis_url=redis_url, ) @@ -468,12 +484,16 @@ def test_complex_filters(cache_with_filters): # test we can do range filters on inserted_at and updated_at fields range_filter = Num("inserted_at") < current_timestamp - results = cache_with_filters.check("prompt 1", filter_expression=range_filter, num_results=5) + results = cache_with_filters.check( + "prompt 1", filter_expression=range_filter, num_results=5 + ) assert len(results) == 2 # test we can combine range filters and text filters prompt_filter = Text("prompt") % "*pt 1" combined_filter = prompt_filter & range_filter - results = cache_with_filters.check("prompt 1", filter_expression=combined_filter, num_results=5) + results = cache_with_filters.check( + "prompt 1", filter_expression=combined_filter, num_results=5 + ) assert len(results) == 1 diff --git a/tests/unit/test_llmcache_schema.py b/tests/unit/test_llmcache_schema.py index 11b8bb58..e3961e6b 100644 --- a/tests/unit/test_llmcache_schema.py +++ b/tests/unit/test_llmcache_schema.py @@ -1,47 +1,51 @@ -import pytest import json +import pytest from pydantic.v1 import ValidationError + from redisvl.extensions.llmcache.schema import CacheEntry, CacheHit -from redisvl.redis.utils import hashify, array_to_buffer +from redisvl.redis.utils import array_to_buffer, hashify def test_valid_cache_entry_creation(): entry = CacheEntry( prompt="What is AI?", response="AI is artificial intelligence.", - prompt_vector=[0.1, 0.2, 0.3] + prompt_vector=[0.1, 0.2, 0.3], ) assert entry.entry_id == hashify("What is AI?") assert entry.prompt == "What is AI?" assert entry.response == "AI is artificial intelligence." assert entry.prompt_vector == [0.1, 0.2, 0.3] + def test_cache_entry_with_given_entry_id(): entry = CacheEntry( entry_id="custom_id", prompt="What is AI?", response="AI is artificial intelligence.", - prompt_vector=[0.1, 0.2, 0.3] + prompt_vector=[0.1, 0.2, 0.3], ) assert entry.entry_id == "custom_id" + def test_cache_entry_with_invalid_metadata(): with pytest.raises(ValidationError): CacheEntry( prompt="What is AI?", response="AI is artificial intelligence.", prompt_vector=[0.1, 0.2, 0.3], - metadata="invalid_metadata" + metadata="invalid_metadata", ) + def test_cache_entry_to_dict(): entry = CacheEntry( prompt="What is AI?", response="AI is artificial intelligence.", prompt_vector=[0.1, 0.2, 0.3], metadata={"author": "John"}, - filters={"category": "technology"} + filters={"category": "technology"}, ) result = entry.to_dict() assert result["entry_id"] == hashify("What is AI?") @@ -50,6 +54,7 @@ def test_cache_entry_to_dict(): assert result["category"] == "technology" assert "filters" not in result + def test_valid_cache_hit_creation(): hit = CacheHit( entry_id="entry_1", @@ -57,7 +62,7 @@ def test_valid_cache_hit_creation(): response="AI is artificial intelligence.", vector_distance=0.1, inserted_at=1625819123.123, - updated_at=1625819123.123 + updated_at=1625819123.123, ) assert hit.entry_id == "entry_1" assert hit.prompt == "What is AI?" @@ -65,6 +70,7 @@ def test_valid_cache_hit_creation(): assert hit.vector_distance == 0.1 assert hit.inserted_at == hit.updated_at == 1625819123.123 + def test_cache_hit_with_serialized_metadata(): hit = CacheHit( entry_id="entry_1", @@ -73,10 +79,11 @@ def test_cache_hit_with_serialized_metadata(): vector_distance=0.1, inserted_at=1625819123.123, updated_at=1625819123.123, - metadata=json.dumps({"author": "John"}) + metadata=json.dumps({"author": "John"}), ) assert hit.metadata == {"author": "John"} + def test_cache_hit_to_dict(): hit = CacheHit( entry_id="entry_1", @@ -85,7 +92,7 @@ def test_cache_hit_to_dict(): vector_distance=0.1, inserted_at=1625819123.123, updated_at=1625819123.123, - filters={"category": "technology"} + filters={"category": "technology"}, ) result = hit.to_dict() assert result["entry_id"] == "entry_1" @@ -95,16 +102,18 @@ def test_cache_hit_to_dict(): assert result["category"] == "technology" assert "filters" not in result + def test_cache_entry_with_empty_optional_fields(): entry = CacheEntry( prompt="What is AI?", response="AI is artificial intelligence.", - prompt_vector=[0.1, 0.2, 0.3] + prompt_vector=[0.1, 0.2, 0.3], ) result = entry.to_dict() assert "metadata" not in result assert "filters" not in result + def test_cache_hit_with_empty_optional_fields(): hit = CacheHit( entry_id="entry_1", @@ -112,7 +121,7 @@ def test_cache_hit_with_empty_optional_fields(): response="AI is artificial intelligence.", vector_distance=0.1, inserted_at=1625819123.123, - updated_at=1625819123.123 + updated_at=1625819123.123, ) result = hit.to_dict() assert "metadata" not in result From 28c7484032bc40ac069123738b3a7703039d070f Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Thu, 1 Aug 2024 12:57:19 -0400 Subject: [PATCH 3/4] documentation cleanup --- redisvl/extensions/llmcache/schema.py | 20 +++++++++ redisvl/extensions/llmcache/semantic.py | 57 ++++++++++++++----------- 2 files changed, 53 insertions(+), 24 deletions(-) diff --git a/redisvl/extensions/llmcache/schema.py b/redisvl/extensions/llmcache/schema.py index 60dcdc94..8075496b 100644 --- a/redisvl/extensions/llmcache/schema.py +++ b/redisvl/extensions/llmcache/schema.py @@ -8,14 +8,24 @@ class CacheEntry(BaseModel): + """A single cache entry in Redis""" + entry_id: Optional[str] = Field(default=None) + """Cache entry identifier""" prompt: str + """Input prompt or question cached in Redis""" response: str + """Response or answer to the question, cached in Redis""" prompt_vector: List[float] + """Text embedding representation of the prompt""" inserted_at: float = Field(default_factory=current_timestamp) + """Timestamp of when the entry was added to the cache""" updated_at: float = Field(default_factory=current_timestamp) + """Timestamp of when the entry was updated in the cache""" metadata: Optional[Dict[str, Any]] = Field(default=None) + """Optional metadata stored on the cache entry""" filters: Optional[Dict[str, Any]] = Field(default=None) + """Optional filter data stored on the cache entry for customizing retrieval""" @root_validator(pre=True) @classmethod @@ -43,14 +53,24 @@ def to_dict(self) -> Dict: class CacheHit(BaseModel): + """A cache hit based on some input query""" + entry_id: str + """Cache entry identifier""" prompt: str + """Input prompt or question cached in Redis""" response: str + """Response or answer to the question, cached in Redis""" vector_distance: float + """The semantic distance between the query vector and the stored prompt vector""" inserted_at: float + """Timestamp of when the entry was added to the cache""" updated_at: float + """Timestamp of when the entry was updated in the cache""" metadata: Optional[Dict[str, Any]] = Field(default=None) + """Optional metadata stored on the cache entry""" filters: Optional[Dict[str, Any]] = Field(default=None) + """Optional filter data stored on the cache entry for customizing retrieval""" @root_validator(pre=True) @classmethod diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index b790e6d1..cc2ce351 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional from redis import Redis @@ -10,20 +10,15 @@ ) from redisvl.index import SearchIndex from redisvl.query import RangeQuery -from redisvl.query.filter import FilterExpression, Tag -from redisvl.redis.utils import array_to_buffer -from redisvl.utils.utils import ( - current_timestamp, - deserialize, - serialize, - validate_vector_dims, -) +from redisvl.query.filter import FilterExpression +from redisvl.utils.utils import current_timestamp, serialize, validate_vector_dims from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer class SemanticCache(BaseLLMCache): """Semantic Cache for Large Language Models.""" + redis_key_field_name: str = "key" entry_id_field_name: str = "entry_id" prompt_field_name: str = "prompt" response_field_name: str = "response" @@ -55,6 +50,8 @@ def __init__( in Redis. Defaults to None. vectorizer (Optional[BaseVectorizer], optional): The vectorizer for the cache. Defaults to HFTextVectorizer. + filterable_fields (Optional[List[Dict[str, Any]]]): An optional list of RedisVL fields + that can be used to customize cache retrieval with filters. redis_client(Optional[Redis], optional): A redis client connection instance. Defaults to None. redis_url (str, optional): The redis url. Defaults to redis://localhost:6379. @@ -81,9 +78,6 @@ def __init__( model="sentence-transformers/all-mpnet-base-v2" ) - # Create semantic cache schema - schema = SemanticCacheIndexSchema.from_params(name, prefix, vectorizer.dims) - # Process fields self.return_fields = [ self.entry_id_field_name, @@ -94,18 +88,9 @@ def __init__( self.metadata_field_name, ] - if filterable_fields is not None: - for filter_field in filterable_fields: - if ( - filter_field["name"] in self.return_fields - or filter_field["name"] == "key" - ): - raise ValueError( - f'{filter_field["name"]} is a reserved field name for the semantic cache schema' - ) - schema.add_field(filter_field) - # Add to return fields too - self.return_fields.append(filter_field["name"]) + # Create semantic cache schema and index + schema = SemanticCacheIndexSchema.from_params(name, prefix, vectorizer.dims) + schema = self._modify_schema(schema, filterable_fields) self._index = SearchIndex(schema=schema) @@ -120,6 +105,30 @@ def __init__( self.set_threshold(distance_threshold) self._index.create(overwrite=False) + def _modify_schema( + self, + schema: SemanticCacheIndexSchema, + filterable_fields: Optional[List[Dict[str, Any]]] = None, + ) -> SemanticCacheIndexSchema: + """Modify the base cache schema using the provided filterable fields""" + + if filterable_fields is not None: + protected_field_names = set( + self.return_fields + [self.redis_key_field_name] + ) + for filter_field in filterable_fields: + field_name = filter_field["name"] + if field_name in protected_field_names: + raise ValueError( + f"{field_name} is a reserved field name for the semantic cache schema" + ) + # Add to schema + schema.add_field(filter_field) + # Add to return fields too + self.return_fields.append(field_name) + + return schema + @property def index(self) -> SearchIndex: """The underlying SearchIndex for the cache. From f6b874515fe8308c3d32fe2d95fcb7ae081d26c7 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Thu, 1 Aug 2024 15:12:49 -0400 Subject: [PATCH 4/4] remove extra comment --- redisvl/extensions/llmcache/semantic.py | 1 - 1 file changed, 1 deletion(-) diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index cc2ce351..f9518614 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -62,7 +62,6 @@ def __init__( TypeError: If an invalid vectorizer is provided. TypeError: If the TTL value is not an int. ValueError: If the threshold is not between 0 and 1. - ValueError: If the index name is not provided """ super().__init__(ttl)