1 change: 1 addition & 0 deletions .circleci/requirements.txt
@@ -5,6 +5,7 @@ tiktoken
importlib_metadata
cohere
redis
redisvl==0.3.2
anthropic
orjson==3.9.15
pydantic==2.7.1
3 changes: 0 additions & 3 deletions Dockerfile
@@ -35,9 +35,6 @@ RUN pip install dist/*.whl
# install dependencies as wheels
RUN pip wheel --no-cache-dir --wheel-dir=/wheels/ -r requirements.txt

# install semantic-cache [Experimental]- we need this here and not in requirements.txt because redisvl pins to pydantic 1.0
RUN pip install redisvl==0.0.7 --no-deps

# ensure pyjwt is used, not jwt
RUN pip uninstall jwt -y
RUN pip uninstall PyJWT -y
3 changes: 0 additions & 3 deletions Dockerfile.database
@@ -50,9 +50,6 @@ COPY --from=builder /wheels/ /wheels/
# Install the built wheel using pip; again using a wildcard if it's the only file
RUN pip install *.whl /wheels/* --no-index --find-links=/wheels/ && rm -f *.whl && rm -rf /wheels

# install semantic-cache [Experimental]- we need this here and not in requirements.txt because redisvl pins to pydantic 1.0
RUN pip install redisvl==0.0.7 --no-deps

# ensure pyjwt is used, not jwt
RUN pip uninstall jwt -y
RUN pip uninstall PyJWT -y
3 changes: 0 additions & 3 deletions Dockerfile.non_root
@@ -50,9 +50,6 @@ COPY --from=builder /wheels/ /wheels/
# Install the built wheel using pip; again using a wildcard if it's the only file
RUN pip install *.whl /wheels/* --no-index --find-links=/wheels/ && rm -f *.whl && rm -rf /wheels

# install semantic-cache [Experimental]- we need this here and not in requirements.txt because redisvl pins to pydantic 1.0
RUN pip install redisvl==0.0.7 --no-deps

# ensure pyjwt is used, not jwt
RUN pip uninstall jwt -y
RUN pip uninstall PyJWT -y
2 changes: 1 addition & 1 deletion docs/my-website/docs/caching/all_caches.md
@@ -90,7 +90,7 @@ response2 = completion(

Install redis
```shell
pip install redisvl==0.0.7
pip install redisvl==0.3.2
```

For the hosted version you can set up your own Redis DB here: https://app.redislabs.com/
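For orientation, here is a minimal sketch of how the semantic cache might be enabled after installing redisvl 0.3.2. The `Cache(type="redis-semantic", ...)` call and its keyword names (notably `redis_semantic_cache_embedding_model`) are assumptions inferred from the `RedisSemanticCache` constructor changed in this PR, not a verbatim copy of the published docs.

```python
# Hedged sketch: enabling litellm's Redis semantic cache (keyword names assumed).
import os

import litellm
from litellm import completion
from litellm.caching import Cache

litellm.cache = Cache(
    type="redis-semantic",                    # semantic cache backend
    host=os.environ["REDIS_HOST"],
    port=os.environ["REDIS_PORT"],
    password=os.environ["REDIS_PASSWORD"],
    similarity_threshold=0.8,                 # RedisSemanticCache turns this into distance_threshold = 1 - 0.8
    redis_semantic_cache_embedding_model="text-embedding-ada-002",  # assumed kwarg name
)

# Two semantically similar prompts should resolve to the same cached response.
response1 = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "What is the capital of France?"}],
)
response2 = completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Tell me the capital city of France"}],
)
```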
223 changes: 82 additions & 141 deletions litellm/caching.py
@@ -913,6 +913,8 @@ def delete_cache(self, key):


class RedisSemanticCache(BaseCache):
DEFAULT_REDIS_INDEX_NAME = "litellm_semantic_cache_index"

def __init__(
self,
host=None,
@@ -922,38 +924,26 @@ def __init__(
similarity_threshold=None,
use_async=False,
embedding_model="text-embedding-ada-002",
index_name=None,
**kwargs,
):
from redisvl.index import SearchIndex
from redisvl.query import VectorQuery
from redisvl.extensions.llmcache import SemanticCache
from redisvl.utils.vectorize import CustomTextVectorizer

if index_name is None:
index_name = self.DEFAULT_REDIS_INDEX_NAME

print_verbose(
"redis semantic-cache initializing INDEX - litellm_semantic_cache_index"
f"redis semantic-cache initializing INDEX - {index_name}"
)

if similarity_threshold is None:
raise Exception("similarity_threshold must be provided, passed None")

self.similarity_threshold = similarity_threshold
self.distance_threshold = 1 - similarity_threshold
self.embedding_model = embedding_model
schema = {
"index": {
"name": "litellm_semantic_cache_index",
"prefix": "litellm",
"storage_type": "hash",
},
"fields": {
"text": [{"name": "response"}],
"text": [{"name": "prompt"}],
"vector": [
{
"name": "litellm_embedding",
"dims": 1536,
"distance_metric": "cosine",
"algorithm": "flat",
"datatype": "float32",
}
],
},
}

if redis_url is None:
# if no url passed, check if host, port and password are passed, if not raise an Exception
if host is None or port is None or password is None:
@@ -967,20 +957,29 @@
raise Exception("Redis host, port, and password must be provided")

redis_url = "redis://:" + password + "@" + host + ":" + port

print_verbose(f"redis semantic-cache redis_url: {redis_url}")
if use_async == False:
self.index = SearchIndex.from_dict(schema)
self.index.connect(redis_url=redis_url)
try:
self.index.create(overwrite=False) # don't overwrite existing index
except Exception as e:
print_verbose(f"Got exception creating semantic cache index: {str(e)}")
elif use_async == True:
schema["index"]["name"] = "litellm_semantic_cache_index_async"
self.index = SearchIndex.from_dict(schema)
self.index.connect(redis_url=redis_url, use_async=True)

# custom vectorizer: embed prompts through litellm.embedding so the cache reuses litellm's embedding path
def generate_cache_embeddings(prompt: str):
# create an embedding from prompt
embedding_response = litellm.embedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
)
embedding = embedding_response["data"][0]["embedding"]
return embedding

cache_vectorizer = CustomTextVectorizer(generate_cache_embeddings)

self.llmcache = SemanticCache(
name=index_name,
redis_url=redis_url,
vectorizer=cache_vectorizer,
distance_threshold=self.distance_threshold,
overwrite=False
)

def _get_cache_logic(self, cached_response: Any):
"""
Common 'get_cache_logic' across sync + async redis client implementations
@@ -998,111 +997,70 @@ def _get_cache_logic(self, cached_response: Any):
) # Convert string to dictionary
except:
cached_response = ast.literal_eval(cached_response)

return cached_response

def set_cache(self, key, value, **kwargs):
import numpy as np

print_verbose(f"redis semantic-cache set_cache, kwargs: {kwargs}")

# get the prompt
# get the prompt and value
messages = kwargs["messages"]
prompt = "".join(message["content"] for message in messages)

# create an embedding for prompt
embedding_response = litellm.embedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
)

# get the embedding
embedding = embedding_response["data"][0]["embedding"]

# make the embedding a numpy array, convert to bytes
embedding_bytes = np.array(embedding, dtype=np.float32).tobytes()
value = str(value)
assert isinstance(value, str)

new_data = [
{"response": value, "prompt": prompt, "litellm_embedding": embedding_bytes}
]

# Add more data
keys = self.index.load(new_data)
# store in redis semantic cache
self.llmcache.store(
prompt=prompt,
response=value
)

return

def get_cache(self, key, **kwargs):
print_verbose(f"sync redis semantic-cache get_cache, kwargs: {kwargs}")
import numpy as np
from redisvl.query import VectorQuery

# query
# get the messages
messages = kwargs["messages"]
prompt = "".join(message["content"] for message in messages)

# convert to embedding
embedding_response = litellm.embedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
)

# get the embedding
embedding = embedding_response["data"][0]["embedding"]

query = VectorQuery(
vector=embedding,
vector_field_name="litellm_embedding",
return_fields=["response", "prompt", "vector_distance"],
num_results=1,
)
# check the cache
results = self.llmcache.check(prompt=prompt)

results = self.index.query(query)
if results == None:
# handle results / cache hit
if not results:
return None
if isinstance(results, list):
if len(results) == 0:
return None

vector_distance = results[0]["vector_distance"]
vector_distance = float(vector_distance)
cache_hit = results[0]
vector_distance = float(cache_hit["vector_distance"])
similarity = 1 - vector_distance
cached_prompt = results[0]["prompt"]
cached_prompt = cache_hit["prompt"]
cached_response = cache_hit["response"]

# check similarity, if more than self.similarity_threshold, return results
print_verbose(
f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
f"got a cache hit: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, current_prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
)
if similarity > self.similarity_threshold:
# cache hit !
cached_value = results[0]["response"]
print_verbose(
f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
)
return self._get_cache_logic(cached_response=cached_value)
else:
# cache miss !
return None
return self._get_cache_logic(cached_response=cached_response)

pass

async def async_set_cache(self, key, value, **kwargs):
import numpy as np
# TODO - patch async support in redisvl for SemanticCache

from litellm.proxy.proxy_server import llm_model_list, llm_router

try:
await self.index.acreate(overwrite=False) # don't overwrite existing index
except Exception as e:
print_verbose(f"Got exception creating semantic cache index: {str(e)}")

print_verbose(f"async redis semantic-cache set_cache, kwargs: {kwargs}")

# get the prompt
# get the prompt and value
messages = kwargs["messages"]
prompt = "".join(message["content"] for message in messages)
value = str(value)
assert isinstance(value, str)

# create an embedding for prompt
router_model_names = (
[m["model_name"] for m in llm_model_list]
@@ -1132,23 +1090,19 @@ async def async_set_cache(self, key, value, **kwargs):
# get the embedding
embedding = embedding_response["data"][0]["embedding"]

# make the embedding a numpy array, convert to bytes
embedding_bytes = np.array(embedding, dtype=np.float32).tobytes()
value = str(value)
assert isinstance(value, str)

new_data = [
{"response": value, "prompt": prompt, "litellm_embedding": embedding_bytes}
]
# store in redis semantic cache
self.llmcache.store(
prompt=prompt,
response=value,
vector=embedding # pass through custom embedding here
)

# Add more data
keys = await self.index.aload(new_data)
return

async def async_get_cache(self, key, **kwargs):
# TODO - patch async support in redisvl for SemanticCache

print_verbose(f"async redis semantic-cache get_cache, kwargs: {kwargs}")
import numpy as np
from redisvl.query import VectorQuery

from litellm.proxy.proxy_server import llm_model_list, llm_router

@@ -1185,47 +1139,34 @@ async def async_get_cache(self, key, **kwargs):
# get the embedding
embedding = embedding_response["data"][0]["embedding"]

query = VectorQuery(
vector=embedding,
vector_field_name="litellm_embedding",
return_fields=["response", "prompt", "vector_distance"],
# check the cache
results = self.llmcache.check(
prompt=prompt, vector=embedding
)
results = await self.index.aquery(query)
if results == None:

# handle results / cache hit
if not results:
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
return None
if isinstance(results, list):
if len(results) == 0:
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
return None

vector_distance = results[0]["vector_distance"]
vector_distance = float(vector_distance)
cache_hit = results[0]
vector_distance = float(cache_hit["vector_distance"])
similarity = 1 - vector_distance

# update kwargs["metadata"] with similarity, don't rewrite the original metadata
kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity
cached_prompt = results[0]["prompt"]
cached_prompt = cache_hit["prompt"]
cached_response = cache_hit["response"]

# check similarity, if more than self.similarity_threshold, return results
print_verbose(
f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
f"got a cache hit: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, current_prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
)
return self._get_cache_logic(cached_response=cached_response)

# update kwargs["metadata"] with similarity, don't rewrite the original metadata
kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity

if similarity > self.similarity_threshold:
# cache hit !
cached_value = results[0]["response"]
print_verbose(
f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
)
return self._get_cache_logic(cached_response=cached_value)
else:
# cache miss !
return None
pass

async def _index_info(self):
return await self.index.ainfo()
# TODO - patch async support in redisvl for SemanticCache
return self.llmcache.index.info()


class QdrantSemanticCache(BaseCache):
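As a companion to the refactor above, here is a self-contained sketch of the redisvl 0.3.2 `SemanticCache` flow the new `RedisSemanticCache` relies on: a callable wrapped in `CustomTextVectorizer`, `store()` for writes, and `check()` for reads. The `embed` function below is a stand-in for `litellm.embedding(...)`, and the result fields (`prompt`, `response`, `vector_distance`) mirror how the diff consumes `check()` results.

```python
# Hedged sketch of the redisvl 0.3.2 SemanticCache API used by this PR.
from redisvl.extensions.llmcache import SemanticCache
from redisvl.utils.vectorize import CustomTextVectorizer


def embed(prompt: str) -> list[float]:
    # Stand-in embedding function; the PR wires litellm.embedding(...) in here instead.
    return [float(len(prompt))] * 1536


vectorizer = CustomTextVectorizer(embed)

llmcache = SemanticCache(
    name="litellm_semantic_cache_index",  # DEFAULT_REDIS_INDEX_NAME in the PR
    redis_url="redis://:mypassword@localhost:6379",
    vectorizer=vectorizer,
    distance_threshold=0.2,               # i.e. a similarity_threshold of 0.8
    overwrite=False,                      # keep an existing index if one is present
)

# Write: store a prompt/response pair (a precomputed vector can also be passed).
llmcache.store(prompt="What is the capital of France?", response="Paris")

# Read: check() returns hits within distance_threshold, or an empty list on a miss.
results = llmcache.check(prompt="Tell me the capital city of France")
if results:
    hit = results[0]
    similarity = 1 - float(hit["vector_distance"])
    print(similarity, hit["prompt"], hit["response"])
```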