1- from typing import Any , Dict , List , Optional , Union
1+ from typing import Any , Dict , List , Optional
22
33from redis import Redis
44
1010)
1111from redisvl .index import SearchIndex
1212from redisvl .query import RangeQuery
13- from redisvl .query .filter import FilterExpression , Tag
14- from redisvl .redis .utils import array_to_buffer
15- from redisvl .utils .utils import (
16- current_timestamp ,
17- deserialize ,
18- serialize ,
19- validate_vector_dims ,
20- )
13+ from redisvl .query .filter import FilterExpression
14+ from redisvl .utils .utils import current_timestamp , serialize , validate_vector_dims
2115from redisvl .utils .vectorize import BaseVectorizer , HFTextVectorizer
2216
2317
2418class SemanticCache (BaseLLMCache ):
2519 """Semantic Cache for Large Language Models."""
2620
21+ redis_key_field_name : str = "key"
2722 entry_id_field_name : str = "entry_id"
2823 prompt_field_name : str = "prompt"
2924 response_field_name : str = "response"
@@ -55,6 +50,8 @@ def __init__(
5550 in Redis. Defaults to None.
5651 vectorizer (Optional[BaseVectorizer], optional): The vectorizer for the cache.
5752 Defaults to HFTextVectorizer.
53+ filterable_fields (Optional[List[Dict[str, Any]]]): An optional list of RedisVL fields
54+ that can be used to customize cache retrieval with filters.
5855 redis_client(Optional[Redis], optional): A redis client connection instance.
5956 Defaults to None.
6057 redis_url (str, optional): The redis url. Defaults to redis://localhost:6379.
@@ -81,9 +78,6 @@ def __init__(
8178 model = "sentence-transformers/all-mpnet-base-v2"
8279 )
8380
84- # Create semantic cache schema
85- schema = SemanticCacheIndexSchema .from_params (name , prefix , vectorizer .dims )
86-
8781 # Process fields
8882 self .return_fields = [
8983 self .entry_id_field_name ,
@@ -94,18 +88,9 @@ def __init__(
9488 self .metadata_field_name ,
9589 ]
9690
97- if filterable_fields is not None :
98- for filter_field in filterable_fields :
99- if (
100- filter_field ["name" ] in self .return_fields
101- or filter_field ["name" ] == "key"
102- ):
103- raise ValueError (
104- f'{ filter_field ["name" ]} is a reserved field name for the semantic cache schema'
105- )
106- schema .add_field (filter_field )
107- # Add to return fields too
108- self .return_fields .append (filter_field ["name" ])
91+ # Create semantic cache schema and index
92+ schema = SemanticCacheIndexSchema .from_params (name , prefix , vectorizer .dims )
93+ schema = self ._modify_schema (schema , filterable_fields )
10994
11095 self ._index = SearchIndex (schema = schema )
11196
@@ -120,6 +105,30 @@ def __init__(
120105 self .set_threshold (distance_threshold )
121106 self ._index .create (overwrite = False )
122107
108+ def _modify_schema (
109+ self ,
110+ schema : SemanticCacheIndexSchema ,
111+ filterable_fields : Optional [List [Dict [str , Any ]]] = None ,
112+ ) -> SemanticCacheIndexSchema :
113+ """Modify the base cache schema using the provided filterable fields"""
114+
115+ if filterable_fields is not None :
116+ protected_field_names = set (
117+ self .return_fields + [self .redis_key_field_name ]
118+ )
119+ for filter_field in filterable_fields :
120+ field_name = filter_field ["name" ]
121+ if field_name in protected_field_names :
122+ raise ValueError (
123+ f"{ field_name } is a reserved field name for the semantic cache schema"
124+ )
125+ # Add to schema
126+ schema .add_field (filter_field )
127+ # Add to return fields too
128+ self .return_fields .append (field_name )
129+
130+ return schema
131+
123132 @property
124133 def index (self ) -> SearchIndex :
125134 """The underlying SearchIndex for the cache.
0 commit comments