Commit
implement redis cache mode, if redis_url is set in the llm_config then
it will try to use this. also adds a test to validate both the existing
and the redis cache behavior.
vijaykramesh committed Jan 12, 2024
1 parent 2e519b0 commit f5f9d39
Showing 8 changed files with 231 additions and 4 deletions.
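For orientation before the file-by-file diff, here is a minimal sketch of how a caller might opt in to the new mode via llm_config. The model name, API key, and Redis URL are placeholder assumptions, not values from this commit:

import autogen
from autogen.agentchat import AssistantAgent

# hypothetical llm_config: cache_seed must be non-None for any caching,
# and redis_url switches the backend from disk to Redis
llm_config = {
    "config_list": [{"model": "gpt-3.5-turbo", "api_key": "sk-placeholder"}],
    "cache_seed": 42,
    "redis_url": "redis://:password@localhost:6379/0",
}

assistant = AssistantAgent("coding_agent", llm_config=llm_config)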
Empty file added autogen/cache/__init__.py
25 changes: 25 additions & 0 deletions autogen/cache/abstract_cache_base.py
@@ -0,0 +1,25 @@
from abc import ABC, abstractmethod


class AbstractCache(ABC):
    """Abstract base class for cache implementations."""

    @abstractmethod
    def get(self, key, default=None):
        """Retrieve a value from the cache, returning default on a miss."""

    @abstractmethod
    def set(self, key, value):
        """Store a value in the cache under the given key."""

    @abstractmethod
    def close(self):
        """Release any resources held by the cache."""

    @abstractmethod
    def __enter__(self):
        """Enter the runtime context for use in a with statement."""

    @abstractmethod
    def __exit__(self, exc_type, exc_value, traceback):
        """Exit the runtime context, releasing resources."""
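A concrete backend only has to fill in these five methods. A minimal in-memory sketch, not part of this commit, with a plain dict standing in for a real store:

class InMemoryCache(AbstractCache):
    def __init__(self):
        self._store = {}

    def get(self, key, default=None):
        return self._store.get(key, default)

    def set(self, key, value):
        self._store[key] = value

    def close(self):
        # nothing to release for a dict-backed cache
        self._store.clear()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()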

14 changes: 14 additions & 0 deletions autogen/cache/cache_factory.py
@@ -0,0 +1,14 @@
from autogen.cache.disk_cache import DiskCache

# RedisCache requires the optional redis dependency; fall back to None if unavailable.
try:
    from autogen.cache.redis_cache import RedisCache
except ImportError:
    RedisCache = None


def cache_factory(seed, redis_url):
    """Factory function for creating cache instances.

    If redis_url is not None and redis is installed, use RedisCache;
    otherwise use DiskCache.
    """
    if RedisCache is not None and redis_url is not None:
        return RedisCache(seed, redis_url)
    return DiskCache(f"./cache/{seed}")
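A quick sketch of the selection logic from the caller's side (the seed and URL here are illustrative):

from autogen.cache.cache_factory import cache_factory

disk = cache_factory(42, None)                               # DiskCache rooted at ./cache/42
maybe_redis = cache_factory(42, "redis://localhost:6379/0")  # RedisCache if redis is installed

Note the fallback is silent: with the redis package missing, a non-None redis_url still yields a DiskCache.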
27 changes: 27 additions & 0 deletions autogen/cache/disk_cache.py
@@ -0,0 +1,27 @@
import diskcache

from .abstract_cache_base import AbstractCache


class DiskCache(AbstractCache):
    """DiskCache implementation of the AbstractCache."""

    def __init__(self, seed):
        # The seed doubles as the on-disk directory for the diskcache store.
        self.cache = diskcache.Cache(seed)

    def get(self, key, default=None):
        return self.cache.get(key, default)

    def set(self, key, value):
        self.cache.set(key, value)

    def close(self):
        self.cache.close()

    def __enter__(self):
        # Return the object itself when entering the context
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Close the cache on exit; exceptions propagate to the caller
        self.close()
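Because __exit__ closes the cache, the intended usage is a with block. A small sketch (the path is arbitrary):

from autogen.cache.disk_cache import DiskCache

with DiskCache("./cache/demo") as cache:
    cache.set("greeting", "hello")
    assert cache.get("greeting") == "hello"
    assert cache.get("missing", "fallback") == "fallback"
# cache.close() has been called here by __exit__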
42 changes: 42 additions & 0 deletions autogen/cache/redis_cache.py
@@ -0,0 +1,42 @@
import pickle

import redis

from .abstract_cache_base import AbstractCache


class RedisCache(AbstractCache):
    """RedisCache implementation of the AbstractCache."""

    def __init__(self, seed, redis_url):
        # Initialize the Redis client; the seed namespaces keys per cache instance
        self.seed = seed
        self.cache = redis.Redis.from_url(redis_url)

    def _prefixed_key(self, key):
        # Prefix the key with the seed so multiple caches can share one Redis instance
        return f"autogen:{self.seed}:{key}"

    def get(self, key, default=None):
        # Fetch from Redis and deserialize; return default on a miss
        result = self.cache.get(self._prefixed_key(key))
        if result is None:
            return default
        return pickle.loads(result)

    def set(self, key, value):
        # Serialize the element with pickle and store it in Redis
        serialized_value = pickle.dumps(value, protocol=0)
        self.cache.set(self._prefixed_key(key), serialized_value)

    def close(self):
        # Close the underlying Redis connection
        self.cache.close()

    def __enter__(self):
        # Return the object itself when entering the context
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Close the connection on exit; exceptions propagate to the caller
        self.close()
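The same contract, against a live server; a sketch assuming Redis is reachable on localhost. Values round-trip through pickle, so anything picklable works:

from autogen.cache.redis_cache import RedisCache

with RedisCache(42, "redis://localhost:6379/0") as cache:
    cache.set("response", {"choices": ["hi"]})  # stored under "autogen:42:response"
    assert cache.get("response") == {"choices": ["hi"]}

One design note: pickle protocol 0 is the ASCII protocol, which keeps stored values human-readable in redis-cli at some cost in size and speed.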
16 changes: 12 additions & 4 deletions autogen/oai/client.py
@@ -9,6 +9,7 @@

 from pydantic import BaseModel

+from autogen.cache.cache_factory import cache_factory
 from autogen.oai import completion

 from autogen.oai.openai_utils import get_key, OAI_PRICE1K
@@ -34,7 +35,9 @@
 )
 from openai.types.completion import Completion
 from openai.types.completion_usage import CompletionUsage
-import diskcache
+
+# cache wrapper
+from autogen.cache.disk_cache import DiskCache

 if openai.__version__ >= "1.1.0":
     TOOL_ENABLED = True
@@ -52,7 +55,7 @@ class OpenAIWrapper:
     """A wrapper class for openai client."""

     cache_path_root: str = ".cache"
-    extra_kwargs = {"cache_seed", "filter_func", "allow_format_str_template", "context", "api_version"}
+    extra_kwargs = {"cache_seed", "redis_url", "filter_func", "allow_format_str_template", "context", "api_version"}
     openai_kwargs = set(inspect.getfullargspec(OpenAI.__init__).kwonlyargs)
     total_usage_summary: Optional[Dict[str, Any]] = None
     actual_usage_summary: Optional[Dict[str, Any]] = None
@@ -221,6 +224,10 @@ def create(self, **config: Any) -> ChatCompletion:
             - `cache_seed` (int | None) for the cache. Default to 41.
                 An integer cache_seed is useful when implementing "controlled randomness" for the completion.
                 None for no caching.
+            - `redis_url` (str | None) for the redis cache. Default to None.
+                A string redis_url formatted like "redis://:password@localhost:6379/0" will turn on the redis cache.
+                None for no redis cache. If `cache_seed` is None, redis_url will be ignored.
+                You must install redis to use the redis cache.
             - filter_func (Callable | None): A function that takes in the context and the response
                 and returns a boolean to indicate whether the response is valid. E.g.,
@@ -248,12 +255,13 @@ def yes_or_no_filter(context, response):
         params = self._construct_create_params(create_config, extra_kwargs)
         # get the cache_seed, filter_func and context
         cache_seed = extra_kwargs.get("cache_seed", 41)
+        redis_url = extra_kwargs.get("redis_url", None)
         filter_func = extra_kwargs.get("filter_func")
         context = extra_kwargs.get("context")

         # Try to load the response from cache
         if cache_seed is not None:
-            with diskcache.Cache(f"{self.cache_path_root}/{cache_seed}") as cache:
+            with cache_factory(f"{self.cache_path_root}/{cache_seed}", redis_url) as cache:
                 # Try to get the response from cache
                 key = get_key(params)
                 response: ChatCompletion = cache.get(key, None)
@@ -290,7 +298,7 @@ def yes_or_no_filter(context, response):
                 self._update_usage_summary(response, use_cache=False)
             if cache_seed is not None:
                 # Cache the response
-                with diskcache.Cache(f"{self.cache_path_root}/{cache_seed}") as cache:
+                with cache_factory(f"{self.cache_path_root}/{cache_seed}", redis_url) as cache:
                     cache.set(key, response)

             # check the filter
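Both the read and write paths key the cache with `get_key(params)`, so hits depend on that key being deterministic for equivalent requests. A sketch of the property being relied on, assuming `get_key` serializes params with sorted keys so ordering does not matter:

from autogen.oai.openai_utils import get_key

a = {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "hi"}]}
b = {"messages": [{"role": "user", "content": "hi"}], "model": "gpt-3.5-turbo"}
assert get_key(a) == get_key(b)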
109 changes: 109 additions & 0 deletions test/agentchat/test_cache.py
@@ -0,0 +1,109 @@
import os
import sys
import time

import pytest

import autogen
from autogen.agentchat import AssistantAgent, UserProxyAgent

sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
from conftest import skip_openai, skip_redis  # noqa: E402

try:
    from openai import OpenAI  # noqa: F401  # imported only to probe availability
except ImportError:
    skip = True
else:
    skip = skip_openai

try:
    import redis  # noqa: F401  # imported only to probe availability
except ImportError:
    skip_redis = True


@pytest.mark.skipif(skip, reason="openai not installed OR requested to skip")
def test_disk_cache(human_input_mode="NEVER", max_consecutive_auto_reply=5):
    random_cache_seed = int.from_bytes(os.urandom(2), "big")
    start_time = time.time()
    cold_cache_messages = run_conversation(
        cache_seed=random_cache_seed,
        human_input_mode=human_input_mode,
        max_consecutive_auto_reply=max_consecutive_auto_reply,
    )
    end_time = time.time()
    duration_with_cold_cache = end_time - start_time

    start_time = time.time()
    warm_cache_messages = run_conversation(
        cache_seed=random_cache_seed,
        human_input_mode=human_input_mode,
        max_consecutive_auto_reply=max_consecutive_auto_reply,
    )
    end_time = time.time()
    duration_with_warm_cache = end_time - start_time
    assert cold_cache_messages == warm_cache_messages
    assert duration_with_warm_cache < duration_with_cold_cache


@pytest.mark.skipif(skip_redis, reason="redis not installed OR requested to skip")
def test_redis_cache(human_input_mode="NEVER", max_consecutive_auto_reply=5):
    random_cache_seed = int.from_bytes(os.urandom(2), "big")
    redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0")
    start_time = time.time()
    cold_cache_messages = run_conversation(
        cache_seed=random_cache_seed,
        redis_url=redis_url,
        human_input_mode=human_input_mode,
        max_consecutive_auto_reply=max_consecutive_auto_reply,
    )
    end_time = time.time()
    duration_with_cold_cache = end_time - start_time

    start_time = time.time()
    warm_cache_messages = run_conversation(
        cache_seed=random_cache_seed,
        redis_url=redis_url,
        human_input_mode=human_input_mode,
        max_consecutive_auto_reply=max_consecutive_auto_reply,
    )
    end_time = time.time()
    duration_with_warm_cache = end_time - start_time
    assert cold_cache_messages == warm_cache_messages
    assert duration_with_warm_cache < duration_with_cold_cache


def run_conversation(cache_seed, redis_url=None, human_input_mode="NEVER", max_consecutive_auto_reply=5):
    KEY_LOC = "notebook"
    OAI_CONFIG_LIST = "OAI_CONFIG_LIST"
    here = os.path.abspath(os.path.dirname(__file__))
    config_list = autogen.config_list_from_json(
        OAI_CONFIG_LIST,
        file_location=KEY_LOC,
        filter_dict={
            "model": {
                "gpt-3.5-turbo",
                "gpt-35-turbo",
                "gpt-3.5-turbo-16k",
                "gpt-3.5-turbo-16k-0613",
                "gpt-3.5-turbo-0301",
                "chatgpt-35-turbo-0301",
                "gpt-35-turbo-v0301",
                "gpt",
            },
        },
    )
    llm_config = {
        "cache_seed": cache_seed,
        "redis_url": redis_url,
        "config_list": config_list,
        "max_tokens": 1024,
    }
    assistant = AssistantAgent(
        "coding_agent",
        llm_config=llm_config,
    )
    user = UserProxyAgent(
        "user",
        human_input_mode=human_input_mode,
        is_termination_msg=lambda x: x.get("content", "").rstrip().endswith("TERMINATE"),
        max_consecutive_auto_reply=max_consecutive_auto_reply,
        code_execution_config={
            "work_dir": f"{here}/test_agent_scripts",
            "use_docker": "python:3",
            "timeout": 60,
        },
        llm_config=llm_config,
        system_message="""Is code provided but not enclosed in ``` blocks?
If so, remind that code blocks need to be enclosed in ``` blocks.
Reply TERMINATE to end the conversation if the task is finished. Don't say appreciation.
If "Thank you" or "You're welcome" are said in the conversation, then say TERMINATE and that is your last message.""",
    )
    user.initiate_chat(assistant, message="TERMINATE")
    # should terminate without sending any message
    assert assistant.last_message()["content"] == assistant.last_message(user)["content"] == "TERMINATE"
    coding_task = "Print hello world to a file called hello.txt"

    # track how long this takes
    user.initiate_chat(assistant, message=coding_task)
    # there is a single conversation partner, so take the first (only) key's messages
    return user.chat_messages[list(user.chat_messages.keys())[0]]

2 changes: 2 additions & 0 deletions test/conftest.py
@@ -14,3 +14,5 @@ def pytest_addoption(parser):
 def pytest_configure(config):
     global skip_openai
     skip_openai = config.getoption("--skip-openai", False)
+    global skip_redis
+    skip_redis = config.getoption("--skip-redis", False)
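This hunk only shows pytest_configure; the matching option registration is not visible in the diff, but presumably mirrors the existing --skip-openai flag along these lines (an assumed shape, not code from this commit):

def pytest_addoption(parser):
    parser.addoption("--skip-openai", action="store_true", help="skip tests that call OpenAI")
    parser.addoption("--skip-redis", action="store_true", help="skip tests that need a running Redis")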
