-
Notifications
You must be signed in to change notification settings - Fork 5.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
implement redis cache mode, if redis_url is set in the llm_config then
it will try to use this. also adds a test to validate both the existing and the redis cache behavior.
- Loading branch information
1 parent
2e519b0
commit f5f9d39
Showing
8 changed files
with
231 additions
and
4 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
class AbstractCache( ABC ): | ||
"""Abstract base class for cache implementations.""" | ||
|
||
@abstractmethod | ||
def get(self, key, default=None): | ||
pass | ||
|
||
@abstractmethod | ||
def set(self, key, value): | ||
pass | ||
|
||
@abstractmethod | ||
def close(self): | ||
pass | ||
|
||
@abstractmethod | ||
def __enter__(self): | ||
pass | ||
|
||
@abstractmethod | ||
def __exit__(self, exc_type, exc_value, traceback): | ||
pass | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from autogen.cache.disk_cache import DiskCache | ||
|
||
try: | ||
from autogen.cache.redis_cache import RedisCache | ||
except ImportError: | ||
RedisCache = None | ||
|
||
|
||
def cache_factory(seed, redis_url): | ||
"""Factory function for creating cache instances. | ||
If redis_url is not None, use RedisCache, otherwise use DiskCache""" | ||
if RedisCache is not None and redis_url is not None: | ||
return RedisCache(seed, redis_url) | ||
return DiskCache(f"./cache/{seed}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import diskcache | ||
|
||
from .abstract_cache_base import AbstractCache | ||
|
||
|
||
class DiskCache(AbstractCache): | ||
"""DiskCache implementation of the AbstractCache.""" | ||
|
||
def __init__(self, seed): | ||
self.cache = diskcache.Cache(seed) | ||
|
||
def get(self, key, default=None): | ||
return self.cache.get(key, default) | ||
|
||
def set(self, key, value): | ||
self.cache.set(key, value) | ||
|
||
def close(self): | ||
self.cache.close() | ||
|
||
def __enter__(self): | ||
# Return the object itself when entering the context | ||
return self | ||
|
||
def __exit__(self, exc_type, exc_value, traceback): | ||
# Clean up resources and handle exceptions if necessary | ||
self.close() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import pickle | ||
|
||
import redis | ||
|
||
from .abstract_cache_base import AbstractCache | ||
|
||
|
||
class RedisCache(AbstractCache): | ||
"""RedisCache implementation of the AbstractCache.""" | ||
|
||
def __init__(self, seed, redis_url): | ||
# Initialize Redis client | ||
self.seed = seed | ||
self.cache = redis.Redis.from_url(redis_url) | ||
|
||
def _prefixed_key(self, key): | ||
# Prefix the key with the seed | ||
return f"autogen:{self.seed}:{key}" | ||
|
||
def get(self, key, default=None): | ||
# Get element from Redis cache and deserialize | ||
result = self.cache.get(self._prefixed_key(key)) | ||
if result is None: | ||
return default | ||
return pickle.loads(result) | ||
|
||
def set(self, key, value): | ||
# Serialize the element and store to Redis cache | ||
serialized_value = pickle.dumps(value, protocol=0) | ||
self.cache.set(self._prefixed_key(key), serialized_value) | ||
|
||
def close(self): | ||
# Close Redis client | ||
self.cache.close() | ||
|
||
def __enter__(self): | ||
# Return the object itself when entering the context | ||
return self | ||
|
||
def __exit__(self, exc_type, exc_value, traceback): | ||
# Clean up resources and handle exceptions if necessary | ||
self.close() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import os | ||
import sys | ||
import time | ||
|
||
import pytest | ||
import autogen | ||
from autogen.agentchat import AssistantAgent, UserProxyAgent | ||
sys.path.append(os.path.join(os.path.dirname(__file__), "..")) | ||
from conftest import skip_openai, skip_redis # noqa: E402 | ||
|
||
try: | ||
from openai import OpenAI | ||
except ImportError: | ||
skip = True | ||
else: | ||
skip = False or skip_openai | ||
|
||
try: | ||
import redis | ||
except ImportError: | ||
skip_redis = True | ||
else: | ||
skip_redis = False or skip_redis | ||
|
||
@pytest.mark.skipif(skip, reason="openai not installed OR requested to skip") | ||
def test_disk_cache(human_input_mode="NEVER", max_consecutive_auto_reply=5): | ||
random_cache_seed = int.from_bytes(os.urandom(2), "big") | ||
start_time = time.time() | ||
cold_cache_messages = run_conversation(cache_seed=random_cache_seed, human_input_mode=human_input_mode, max_consecutive_auto_reply=max_consecutive_auto_reply) | ||
end_time = time.time() | ||
duration_with_cold_cache = end_time - start_time | ||
|
||
start_time = time.time() | ||
warm_cache_messages = run_conversation(cache_seed=random_cache_seed, human_input_mode=human_input_mode, max_consecutive_auto_reply=max_consecutive_auto_reply) | ||
end_time = time.time() | ||
duration_with_warm_cache = end_time - start_time | ||
assert cold_cache_messages == warm_cache_messages | ||
assert duration_with_warm_cache < duration_with_cold_cache | ||
|
||
@pytest.mark.skipif(skip_redis, reason="redis not installed OR requested to skip") | ||
def test_redis_cache(human_input_mode="NEVER", max_consecutive_auto_reply=5): | ||
random_cache_seed = int.from_bytes(os.urandom(2), "big") | ||
redis_url = os.getenv("REDIS_URL", "redis://localhost:6379/0") | ||
start_time = time.time() | ||
cold_cache_messages = run_conversation(cache_seed=random_cache_seed, redis_url=redis_url, human_input_mode=human_input_mode, max_consecutive_auto_reply=max_consecutive_auto_reply) | ||
end_time = time.time() | ||
duration_with_cold_cache = end_time - start_time | ||
|
||
start_time = time.time() | ||
warm_cache_messages = run_conversation(cache_seed=random_cache_seed, redis_url=redis_url, human_input_mode=human_input_mode, max_consecutive_auto_reply=max_consecutive_auto_reply) | ||
end_time = time.time() | ||
duration_with_warm_cache = end_time - start_time | ||
assert cold_cache_messages == warm_cache_messages | ||
assert duration_with_warm_cache < duration_with_cold_cache | ||
def run_conversation(cache_seed, redis_url=None, human_input_mode="NEVER", max_consecutive_auto_reply=5): | ||
KEY_LOC = "notebook" | ||
OAI_CONFIG_LIST = "OAI_CONFIG_LIST" | ||
here = os.path.abspath(os.path.dirname(__file__)) | ||
config_list = autogen.config_list_from_json( | ||
OAI_CONFIG_LIST, | ||
file_location=KEY_LOC, | ||
filter_dict={ | ||
"model": { | ||
"gpt-3.5-turbo", | ||
"gpt-35-turbo", | ||
"gpt-3.5-turbo-16k", | ||
"gpt-3.5-turbo-16k-0613", | ||
"gpt-3.5-turbo-0301", | ||
"chatgpt-35-turbo-0301", | ||
"gpt-35-turbo-v0301", | ||
"gpt", | ||
}, | ||
}, | ||
) | ||
llm_config = { | ||
"cache_seed": cache_seed, | ||
"redis_url": redis_url, | ||
"config_list": config_list, | ||
"max_tokens": 1024, | ||
} | ||
assistant = AssistantAgent( | ||
"coding_agent", | ||
llm_config=llm_config, | ||
) | ||
user = UserProxyAgent( | ||
"user", | ||
human_input_mode=human_input_mode, | ||
is_termination_msg=lambda x: x.get("content", "").rstrip().endswith("TERMINATE"), | ||
max_consecutive_auto_reply=max_consecutive_auto_reply, | ||
code_execution_config={ | ||
"work_dir": f"{here}/test_agent_scripts", | ||
"use_docker": "python:3", | ||
"timeout": 60, | ||
}, | ||
llm_config=llm_config, | ||
system_message="""Is code provided but not enclosed in ``` blocks? | ||
If so, remind that code blocks need to be enclosed in ``` blocks. | ||
Reply TERMINATE to end the conversation if the task is finished. Don't say appreciation. | ||
If "Thank you" or "You\'re welcome" are said in the conversation, then say TERMINATE and that is your last message.""", | ||
) | ||
user.initiate_chat(assistant, message="TERMINATE") | ||
# should terminate without sending any message | ||
assert assistant.last_message()["content"] == assistant.last_message(user)["content"] == "TERMINATE" | ||
coding_task = "Print hello world to a file called hello.txt" | ||
|
||
# track how long this takes | ||
user.initiate_chat(assistant, message=coding_task) | ||
return user.chat_messages[list(user.chat_messages.keys())[-0]] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters