Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] implement redis cache mode #1222

Merged
merged 17 commits into from
Jan 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ jobs:
if: matrix.python-version == '3.10'
run: |
pip install -e .[test]
pip install -e .[redis]
coverage run -a -m pytest test --ignore=test/agentchat/contrib --skip-openai
coverage xml
- name: Upload coverage to Codecov
Expand Down
7 changes: 7 additions & 0 deletions .github/workflows/openai.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ jobs:
python-version: ["3.9", "3.10", "3.11", "3.12"]
runs-on: ${{ matrix.os }}
environment: openai1
services:
redis:
image: redis
ports:
- 6379:6379
options: --entrypoint redis-server
steps:
# checkout to pr branch
- name: Checkout
Expand All @@ -42,6 +48,7 @@ jobs:
if: matrix.python-version == '3.9'
run: |
pip install docker
pip install -e .[redis]
- name: Coverage
if: matrix.python-version == '3.9'
env:
Expand Down
23 changes: 22 additions & 1 deletion autogen/agentchat/conversable_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Tuple, Type, TypeVar, Union

from .. import OpenAIWrapper
from ..cache.cache import Cache
from ..code_utils import (
DEFAULT_MODEL,
UNKNOWN,
Expand Down Expand Up @@ -135,6 +136,9 @@ def __init__(
self.llm_config.update(llm_config)
self.client = OpenAIWrapper(**self.llm_config)

# Initialize standalone client cache object.
self.client_cache = None

self._code_execution_config: Union[Dict, Literal[False]] = (
{} if code_execution_config is None else code_execution_config
)
Expand Down Expand Up @@ -665,6 +669,7 @@ def initiate_chat(
recipient: "ConversableAgent",
clear_history: Optional[bool] = True,
silent: Optional[bool] = False,
cache: Optional[Cache] = None,
**context,
):
"""Initiate a chat with the recipient agent.
Expand All @@ -677,6 +682,7 @@ def initiate_chat(
recipient: the recipient agent.
clear_history (bool): whether to clear the chat history with the agent.
silent (bool or None): (Experimental) whether to print the messages for this conversation.
cache (Cache or None): the cache client to be used for this conversation.
**context: any context information.
"message" needs to be provided if the `generate_init_message` method is not overridden.
Otherwise, input() will be called to get the initial message.
Expand All @@ -686,14 +692,20 @@ def initiate_chat(
"""
for agent in [self, recipient]:
agent._raise_exception_on_async_reply_functions()
agent.previous_cache = agent.client_cache
agent.client_cache = cache
self._prepare_chat(recipient, clear_history)
self.send(self.generate_init_message(**context), recipient, silent=silent)
ekzhu marked this conversation as resolved.
Show resolved Hide resolved
for agent in [self, recipient]:
agent.client_cache = agent.previous_cache
agent.previous_cache = None

async def a_initiate_chat(
self,
recipient: "ConversableAgent",
clear_history: Optional[bool] = True,
silent: Optional[bool] = False,
cache: Optional[Cache] = None,
**context,
):
"""(async) Initiate a chat with the recipient agent.
Expand All @@ -706,12 +718,19 @@ async def a_initiate_chat(
recipient: the recipient agent.
clear_history (bool): whether to clear the chat history with the agent.
silent (bool or None): (Experimental) whether to print the messages for this conversation.
cache (Cache or None): the cache client to be used for this conversation.
**context: any context information.
"message" needs to be provided if the `generate_init_message` method is not overridden.
Otherwise, input() will be called to get the initial message.
"""
self._prepare_chat(recipient, clear_history)
for agent in [self, recipient]:
agent.previous_cache = agent.client_cache
agent.client_cache = cache
await self.a_send(await self.a_generate_init_message(**context), recipient, silent=silent)
for agent in [self, recipient]:
agent.client_cache = agent.previous_cache
agent.previous_cache = None

def reset(self):
"""Reset the agent."""
Expand Down Expand Up @@ -778,7 +797,9 @@ def generate_oai_reply(

# TODO: #1143 handle token limit exceeded error
response = client.create(
context=messages[-1].pop("context", None), messages=self._oai_system_message + all_messages
context=messages[-1].pop("context", None),
messages=self._oai_system_message + all_messages,
cache=self.client_cache,
)

extracted_response = client.extract_text_or_completion_object(response)[0]
Expand Down
18 changes: 17 additions & 1 deletion autogen/agentchat/groupchat.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,13 +349,17 @@ def run_chat(
messages: Optional[List[Dict]] = None,
sender: Optional[Agent] = None,
config: Optional[GroupChat] = None,
) -> Union[str, Dict, None]:
) -> Tuple[bool, Optional[str]]:
"""Run a group chat."""
if messages is None:
messages = self._oai_messages[sender]
message = messages[-1]
speaker = sender
groupchat = config
if self.client_cache is not None:
for a in groupchat.agents:
a.previous_cache = a.client_cache
a.client_cache = self.client_cache
for i in range(groupchat.max_round):
groupchat.append(message, speaker)
if self._is_termination_msg(message):
Expand Down Expand Up @@ -389,6 +393,10 @@ def run_chat(
message = self.last_message(speaker)
if i == groupchat.max_round - 1:
groupchat.append(message, speaker)
if self.client_cache is not None:
for a in groupchat.agents:
a.client_cache = a.previous_cache
a.previous_cache = None
return True, None

async def a_run_chat(
ekzhu marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -403,6 +411,10 @@ async def a_run_chat(
message = messages[-1]
speaker = sender
groupchat = config
if self.client_cache is not None:
for a in groupchat.agents:
a.previous_cache = a.client_cache
a.client_cache = self.client_cache
for i in range(groupchat.max_round):
groupchat.append(message, speaker)

Expand Down Expand Up @@ -436,6 +448,10 @@ async def a_run_chat(
# The speaker sends the message without requesting a reply
await speaker.a_send(reply, self, request_reply=False)
message = self.last_message(speaker)
if self.client_cache is not None:
for a in groupchat.agents:
a.client_cache = a.previous_cache
a.previous_cache = None
return True, None

def _raise_exception_on_async_reply_functions(self) -> None:
Expand Down
Empty file added autogen/cache/__init__.py
Empty file.
90 changes: 90 additions & 0 deletions autogen/cache/abstract_cache_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from abc import ABC, abstractmethod


class AbstractCache(ABC):
"""
Abstract base class for cache implementations.

This class defines the basic interface for cache operations.
Implementing classes should provide concrete implementations for
these methods to handle caching mechanisms.
"""

@abstractmethod
def get(self, key, default=None):
"""
Retrieve an item from the cache.

Abstract method that must be implemented by subclasses to
retrieve an item from the cache.

Args:
key (str): The key identifying the item in the cache.
default (optional): The default value to return if the key is not found.
Defaults to None.

Returns:
The value associated with the key if found, else the default value.

Raises:
NotImplementedError: If the subclass does not implement this method.
"""

@abstractmethod
def set(self, key, value):
"""
Set an item in the cache.

Abstract method that must be implemented by subclasses to
store an item in the cache.

Args:
key (str): The key under which the item is to be stored.
value: The value to be stored in the cache.

Raises:
NotImplementedError: If the subclass does not implement this method.
"""

@abstractmethod
def close(self):
"""
Close the cache.

Abstract method that should be implemented by subclasses to
perform any necessary cleanup, such as closing network connections or
releasing resources.

Raises:
NotImplementedError: If the subclass does not implement this method.
"""

@abstractmethod
def __enter__(self):
"""
Enter the runtime context related to this object.

The with statement will bind this method’s return value to the target(s)
specified in the as clause of the statement, if any.

Raises:
NotImplementedError: If the subclass does not implement this method.
"""

@abstractmethod
def __exit__(self, exc_type, exc_value, traceback):
"""
Exit the runtime context and close the cache.

Abstract method that should be implemented by subclasses to handle
the exit from a with statement. It is responsible for resource
release and cleanup.

Args:
exc_type: The exception type if an exception was raised in the context.
exc_value: The exception value if an exception was raised in the context.
traceback: The traceback if an exception was raised in the context.

Raises:
NotImplementedError: If the subclass does not implement this method.
"""
137 changes: 137 additions & 0 deletions autogen/cache/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import os
from typing import Dict, Any

from autogen.cache.cache_factory import CacheFactory


class Cache:
"""
A wrapper class for managing cache configuration and instances.

This class provides a unified interface for creating and interacting with
different types of cache (e.g., Redis, Disk). It abstracts the underlying
cache implementation details, providing methods for cache operations.

Attributes:
config (Dict[str, Any]): A dictionary containing cache configuration.
cache: The cache instance created based on the provided configuration.

Methods:
redis(cache_seed=42, redis_url="redis://localhost:6379/0"): Static method to create a Redis cache instance.
disk(cache_seed=42, cache_path_root=".cache"): Static method to create a Disk cache instance.
__init__(self, config): Initializes the Cache with the given configuration.
__enter__(self): Context management entry, returning the cache instance.
__exit__(self, exc_type, exc_value, traceback): Context management exit.
get(self, key, default=None): Retrieves an item from the cache.
set(self, key, value): Sets an item in the cache.
close(self): Closes the cache.
"""

ALLOWED_CONFIG_KEYS = ["cache_seed", "redis_url", "cache_path_root"]

@staticmethod
def redis(cache_seed=42, redis_url="redis://localhost:6379/0"):
"""
Create a Redis cache instance.

Args:
cache_seed (int, optional): A seed for the cache. Defaults to 42.
redis_url (str, optional): The URL for the Redis server. Defaults to "redis://localhost:6379/0".

Returns:
Cache: A Cache instance configured for Redis.
"""
return Cache({"cache_seed": cache_seed, "redis_url": redis_url})

@staticmethod
def disk(cache_seed=42, cache_path_root=".cache"):
"""
Create a Disk cache instance.

Args:
cache_seed (int, optional): A seed for the cache. Defaults to 42.
cache_path_root (str, optional): The root path for the disk cache. Defaults to ".cache".

Returns:
Cache: A Cache instance configured for Disk caching.
"""
return Cache({"cache_seed": cache_seed, "cache_path_root": cache_path_root})

def __init__(self, config: Dict[str, Any]):
"""
Initialize the Cache with the given configuration.

Validates the configuration keys and creates the cache instance.

Args:
config (Dict[str, Any]): A dictionary containing the cache configuration.

Raises:
ValueError: If an invalid configuration key is provided.
"""
self.config = config
# validate config
for key in self.config.keys():
if key not in self.ALLOWED_CONFIG_KEYS:
raise ValueError(f"Invalid config key: {key}")
# create cache instance
self.cache = CacheFactory.cache_factory(
self.config.get("cache_seed", "42"),
self.config.get("redis_url", None),
self.config.get("cache_path_root", None),
)

def __enter__(self):
"""
Enter the runtime context related to the cache object.

Returns:
The cache instance for use within a context block.
"""
return self.cache.__enter__()

def __exit__(self, exc_type, exc_value, traceback):
"""
Exit the runtime context related to the cache object.

Cleans up the cache instance and handles any exceptions that occurred
within the context.

Args:
exc_type: The exception type if an exception was raised in the context.
exc_value: The exception value if an exception was raised in the context.
traceback: The traceback if an exception was raised in the context.
"""
return self.cache.__exit__(exc_type, exc_value, traceback)

def get(self, key, default=None):
"""
Retrieve an item from the cache.

Args:
key (str): The key identifying the item in the cache.
default (optional): The default value to return if the key is not found.
Defaults to None.

Returns:
The value associated with the key if found, else the default value.
"""
return self.cache.get(key, default)

def set(self, key, value):
"""
Set an item in the cache.

Args:
key (str): The key under which the item is to be stored.
value: The value to be stored in the cache.
"""
self.cache.set(key, value)

def close(self):
"""
Close the cache.

Perform any necessary cleanup, such as closing connections or releasing resources.
"""
self.cache.close()
Loading
Loading