Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions redisvl/extensions/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""
Constants used within the extension classes SemanticCache, BaseSessionManager,
StandardSessionManager, SemanticSessionManager, and SemanticRouter.
These constants are also used within these classes' corresponding schemas.
"""

# BaseSessionManager: field names shared by all session-manager chat entries.
ID_FIELD_NAME: str = "entry_id"
ROLE_FIELD_NAME: str = "role"
CONTENT_FIELD_NAME: str = "content"
TOOL_FIELD_NAME: str = "tool_call_id"
TIMESTAMP_FIELD_NAME: str = "timestamp"
SESSION_FIELD_NAME: str = "session_tag"

# SemanticSessionManager: name of the message-embedding vector field.
SESSION_VECTOR_FIELD_NAME: str = "vector_field"

# SemanticCache: field names for cached prompt/response entries.
REDIS_KEY_FIELD_NAME: str = "key"
ENTRY_ID_FIELD_NAME: str = "entry_id"
PROMPT_FIELD_NAME: str = "prompt"
RESPONSE_FIELD_NAME: str = "response"
CACHE_VECTOR_FIELD_NAME: str = "prompt_vector"
INSERTED_AT_FIELD_NAME: str = "inserted_at"
UPDATED_AT_FIELD_NAME: str = "updated_at"
METADATA_FIELD_NAME: str = "metadata"

# SemanticRouter: name of the route-reference vector field.
ROUTE_VECTOR_FIELD_NAME: str = "vector"
17 changes: 12 additions & 5 deletions redisvl/extensions/llmcache/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@

from pydantic.v1 import BaseModel, Field, root_validator, validator

from redisvl.extensions.constants import (
CACHE_VECTOR_FIELD_NAME,
INSERTED_AT_FIELD_NAME,
PROMPT_FIELD_NAME,
RESPONSE_FIELD_NAME,
UPDATED_AT_FIELD_NAME,
)
from redisvl.redis.utils import array_to_buffer, hashify
from redisvl.schema import IndexSchema
from redisvl.utils.utils import current_timestamp, deserialize, serialize
Expand Down Expand Up @@ -110,12 +117,12 @@ def from_params(cls, name: str, prefix: str, vector_dims: int):
return cls(
index={"name": name, "prefix": prefix}, # type: ignore
fields=[ # type: ignore
{"name": "prompt", "type": "text"},
{"name": "response", "type": "text"},
{"name": "inserted_at", "type": "numeric"},
{"name": "updated_at", "type": "numeric"},
{"name": PROMPT_FIELD_NAME, "type": "text"},
{"name": RESPONSE_FIELD_NAME, "type": "text"},
{"name": INSERTED_AT_FIELD_NAME, "type": "numeric"},
{"name": UPDATED_AT_FIELD_NAME, "type": "numeric"},
{
"name": "prompt_vector",
"name": CACHE_VECTOR_FIELD_NAME,
"type": "vector",
"attrs": {
"dims": vector_dims,
Expand Down
65 changes: 30 additions & 35 deletions redisvl/extensions/llmcache/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,16 @@

from redis import Redis

from redisvl.extensions.constants import (
CACHE_VECTOR_FIELD_NAME,
ENTRY_ID_FIELD_NAME,
INSERTED_AT_FIELD_NAME,
METADATA_FIELD_NAME,
PROMPT_FIELD_NAME,
REDIS_KEY_FIELD_NAME,
RESPONSE_FIELD_NAME,
UPDATED_AT_FIELD_NAME,
)
from redisvl.extensions.llmcache.base import BaseLLMCache
from redisvl.extensions.llmcache.schema import (
CacheEntry,
Expand All @@ -19,15 +29,6 @@
class SemanticCache(BaseLLMCache):
"""Semantic Cache for Large Language Models."""

redis_key_field_name: str = "key"
entry_id_field_name: str = "entry_id"
prompt_field_name: str = "prompt"
response_field_name: str = "response"
vector_field_name: str = "prompt_vector"
inserted_at_field_name: str = "inserted_at"
updated_at_field_name: str = "updated_at"
metadata_field_name: str = "metadata"

_index: SearchIndex
_aindex: Optional[AsyncSearchIndex] = None

Expand Down Expand Up @@ -94,12 +95,12 @@ def __init__(
# Process fields and other settings
self.set_threshold(distance_threshold)
self.return_fields = [
self.entry_id_field_name,
self.prompt_field_name,
self.response_field_name,
self.inserted_at_field_name,
self.updated_at_field_name,
self.metadata_field_name,
ENTRY_ID_FIELD_NAME,
PROMPT_FIELD_NAME,
RESPONSE_FIELD_NAME,
INSERTED_AT_FIELD_NAME,
UPDATED_AT_FIELD_NAME,
METADATA_FIELD_NAME,
]

# Create semantic cache schema and index
Expand Down Expand Up @@ -133,7 +134,7 @@ def __init__(

validate_vector_dims(
vectorizer.dims,
self._index.schema.fields[self.vector_field_name].attrs.dims, # type: ignore
self._index.schema.fields[CACHE_VECTOR_FIELD_NAME].attrs.dims, # type: ignore
)
self._vectorizer = vectorizer

Expand All @@ -145,9 +146,7 @@ def _modify_schema(
"""Modify the base cache schema using the provided filterable fields"""

if filterable_fields is not None:
protected_field_names = set(
self.return_fields + [self.redis_key_field_name]
)
protected_field_names = set(self.return_fields + [REDIS_KEY_FIELD_NAME])
for filter_field in filterable_fields:
field_name = filter_field["name"]
if field_name in protected_field_names:
Expand Down Expand Up @@ -300,7 +299,7 @@ async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]:
def _check_vector_dims(self, vector: List[float]):
"""Checks the size of the provided vector and raises an error if it
doesn't match the search index vector dimensions."""
schema_vector_dims = self._index.schema.fields[self.vector_field_name].attrs.dims # type: ignore
schema_vector_dims = self._index.schema.fields[CACHE_VECTOR_FIELD_NAME].attrs.dims # type: ignore
validate_vector_dims(len(vector), schema_vector_dims)

def check(
Expand Down Expand Up @@ -363,7 +362,7 @@ def check(

query = RangeQuery(
vector=vector,
vector_field_name=self.vector_field_name,
vector_field_name=CACHE_VECTOR_FIELD_NAME,
return_fields=self.return_fields,
distance_threshold=distance_threshold,
num_results=num_results,
Expand Down Expand Up @@ -444,7 +443,7 @@ async def acheck(

query = RangeQuery(
vector=vector,
vector_field_name=self.vector_field_name,
vector_field_name=CACHE_VECTOR_FIELD_NAME,
return_fields=self.return_fields,
distance_threshold=distance_threshold,
num_results=num_results,
Expand Down Expand Up @@ -479,7 +478,7 @@ def _process_cache_results(
cache_hit_dict = {
k: v for k, v in cache_hit_dict.items() if k in return_fields
}
cache_hit_dict[self.redis_key_field_name] = redis_key
cache_hit_dict[REDIS_KEY_FIELD_NAME] = redis_key
cache_hits.append(cache_hit_dict)
return redis_keys, cache_hits

Expand Down Expand Up @@ -541,7 +540,7 @@ def store(
keys = self._index.load(
data=[cache_entry.to_dict()],
ttl=ttl,
id_field=self.entry_id_field_name,
id_field=ENTRY_ID_FIELD_NAME,
)
return keys[0]

Expand Down Expand Up @@ -605,7 +604,7 @@ async def astore(
keys = await aindex.load(
data=[cache_entry.to_dict()],
ttl=ttl,
id_field=self.entry_id_field_name,
id_field=ENTRY_ID_FIELD_NAME,
)
return keys[0]

Expand All @@ -629,21 +628,19 @@ def update(self, key: str, **kwargs) -> None:
for k, v in kwargs.items():

# Make sure the item is in the index schema
if k not in set(
self._index.schema.field_names + [self.metadata_field_name]
):
if k not in set(self._index.schema.field_names + [METADATA_FIELD_NAME]):
raise ValueError(f"{k} is not a valid field within the cache entry")

# Check for metadata and deserialize
if k == self.metadata_field_name:
if k == METADATA_FIELD_NAME:
if isinstance(v, dict):
kwargs[k] = serialize(v)
else:
raise TypeError(
"If specified, cached metadata must be a dictionary."
)

kwargs.update({self.updated_at_field_name: current_timestamp()})
kwargs.update({UPDATED_AT_FIELD_NAME: current_timestamp()})

self._index.client.hset(key, mapping=kwargs) # type: ignore

Expand Down Expand Up @@ -674,21 +671,19 @@ async def aupdate(self, key: str, **kwargs) -> None:
for k, v in kwargs.items():

# Make sure the item is in the index schema
if k not in set(
self._index.schema.field_names + [self.metadata_field_name]
):
if k not in set(self._index.schema.field_names + [METADATA_FIELD_NAME]):
raise ValueError(f"{k} is not a valid field within the cache entry")

# Check for metadata and deserialize
if k == self.metadata_field_name:
if k == METADATA_FIELD_NAME:
if isinstance(v, dict):
kwargs[k] = serialize(v)
else:
raise TypeError(
"If specified, cached metadata must be a dictionary."
)

kwargs.update({self.updated_at_field_name: current_timestamp()})
kwargs.update({UPDATED_AT_FIELD_NAME: current_timestamp()})

await aindex.load(data=[kwargs], keys=[key])

Expand Down
17 changes: 8 additions & 9 deletions redisvl/extensions/session_manager/base_session.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
from typing import Any, Dict, List, Optional, Union

from redisvl.extensions.constants import (
CONTENT_FIELD_NAME,
ROLE_FIELD_NAME,
TOOL_FIELD_NAME,
)
from redisvl.extensions.session_manager.schema import ChatMessage
from redisvl.utils.utils import create_uuid


class BaseSessionManager:
id_field_name: str = "entry_id"
role_field_name: str = "role"
content_field_name: str = "content"
tool_field_name: str = "tool_call_id"
timestamp_field_name: str = "timestamp"
session_field_name: str = "session_tag"

def __init__(
self,
Expand Down Expand Up @@ -107,11 +106,11 @@ def _format_context(
context.append(chat_message.content)
else:
chat_message_dict = {
self.role_field_name: chat_message.role,
self.content_field_name: chat_message.content,
ROLE_FIELD_NAME: chat_message.role,
CONTENT_FIELD_NAME: chat_message.content,
}
if chat_message.tool_call_id is not None:
chat_message_dict[self.tool_field_name] = chat_message.tool_call_id
chat_message_dict[TOOL_FIELD_NAME] = chat_message.tool_call_id

context.append(chat_message_dict) # type: ignore

Expand Down
53 changes: 36 additions & 17 deletions redisvl/extensions/session_manager/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,15 @@

from pydantic.v1 import BaseModel, Field, root_validator

from redisvl.extensions.constants import (
CONTENT_FIELD_NAME,
ID_FIELD_NAME,
ROLE_FIELD_NAME,
SESSION_FIELD_NAME,
SESSION_VECTOR_FIELD_NAME,
TIMESTAMP_FIELD_NAME,
TOOL_FIELD_NAME,
)
from redisvl.redis.utils import array_to_buffer
from redisvl.schema import IndexSchema
from redisvl.utils.utils import current_timestamp
Expand Down Expand Up @@ -31,18 +40,28 @@ class Config:
@root_validator(pre=True)
@classmethod
def generate_id(cls, values):
if "timestamp" not in values:
values["timestamp"] = current_timestamp()
if "entry_id" not in values:
values["entry_id"] = f'{values["session_tag"]}:{values["timestamp"]}'
###if "timestamp" not in values:
### values["timestamp"] = current_timestamp()
if TIMESTAMP_FIELD_NAME not in values: ###
values[TIMESTAMP_FIELD_NAME] = current_timestamp() ###
###if "entry_id" not in values:
### values["entry_id"] = f'{values["session_tag"]}:{values["timestamp"]}'
if ID_FIELD_NAME not in values: ###
values[ID_FIELD_NAME] = (
f"{values[SESSION_FIELD_NAME]}:{values[TIMESTAMP_FIELD_NAME]}" ###
)
return values

def to_dict(self) -> Dict:
data = self.dict(exclude_none=True)

# handle optional fields
if "vector_field" in data:
data["vector_field"] = array_to_buffer(data["vector_field"])
###if "vector_field" in data:
### data["vector_field"] = array_to_buffer(data["vector_field"])
if SESSION_VECTOR_FIELD_NAME in data:
data[SESSION_VECTOR_FIELD_NAME] = array_to_buffer(
data[SESSION_VECTOR_FIELD_NAME]
)

return data

Expand All @@ -55,11 +74,11 @@ def from_params(cls, name: str, prefix: str):
return cls(
index={"name": name, "prefix": prefix}, # type: ignore
fields=[ # type: ignore
{"name": "role", "type": "tag"},
{"name": "content", "type": "text"},
{"name": "tool_call_id", "type": "tag"},
{"name": "timestamp", "type": "numeric"},
{"name": "session_tag", "type": "tag"},
{"name": ROLE_FIELD_NAME, "type": "tag"},
{"name": CONTENT_FIELD_NAME, "type": "text"},
{"name": TOOL_FIELD_NAME, "type": "tag"},
{"name": TIMESTAMP_FIELD_NAME, "type": "numeric"},
{"name": SESSION_FIELD_NAME, "type": "tag"},
],
)

Expand All @@ -72,13 +91,13 @@ def from_params(cls, name: str, prefix: str, vectorizer_dims: int):
return cls(
index={"name": name, "prefix": prefix}, # type: ignore
fields=[ # type: ignore
{"name": "role", "type": "tag"},
{"name": "content", "type": "text"},
{"name": "tool_call_id", "type": "tag"},
{"name": "timestamp", "type": "numeric"},
{"name": "session_tag", "type": "tag"},
{"name": ROLE_FIELD_NAME, "type": "tag"},
{"name": CONTENT_FIELD_NAME, "type": "text"},
{"name": TOOL_FIELD_NAME, "type": "tag"},
{"name": TIMESTAMP_FIELD_NAME, "type": "numeric"},
{"name": SESSION_FIELD_NAME, "type": "tag"},
{
"name": "vector_field",
"name": SESSION_VECTOR_FIELD_NAME,
"type": "vector",
"attrs": {
"dims": vectorizer_dims,
Expand Down
Loading