Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@

try:
from redis import Redis
from redisvl.extensions.message_history import SemanticMessageHistory
from redisvl.extensions.message_history import MessageHistory, SemanticMessageHistory
from redisvl.utils.utils import deserialize, serialize
from redisvl.utils.vectorize import HFTextVectorizer
except ImportError as e:
raise ImportError("To use Redis Memory RedisVL must be installed. Run `pip install autogen-ext[redisvl]`") from e

Expand All @@ -29,24 +30,25 @@ class RedisMemoryConfig(BaseModel):
redis_url: str = Field(default="redis://localhost:6379", description="url of the Redis instance")
index_name: str = Field(default="chat_history", description="Name of the Redis collection")
prefix: str = Field(default="memory", description="prefix of the Redis collection")
sequential: bool = Field(
default=False, description="ignore semantic similarity and simply return memories in sequential order"
)
distance_metric: Literal["cosine", "ip", "l2"] = "cosine"
algorithm: Literal["flat", "hnsw"] = "flat"
top_k: int = Field(default=10, description="Number of results to return in queries")
datatype: Literal["uint8", "int8", "float16", "float32", "float64", "bfloat16"] = "float32"
distance_threshold: float = Field(default=0.7, description="Minimum similarity score threshold")
model_name: str | None = Field(
default="sentence-transformers/all-mpnet-base-v2", description="Embedding model name"
)
model_name: str = Field(default="sentence-transformers/all-mpnet-base-v2", description="Embedding model name")


class RedisMemory(Memory, Component[RedisMemoryConfig]):
"""
Store and retrieve memory using vector similarity search powered by RedisVL.

`RedisMemory` provides a vector-based memory implementation that uses RedisVL for storing and
retrieving content based on semantic similarity. It enhances agents with the ability to recall
contextually relevant information during conversations by leveraging vector embeddings to find
similar content.
retrieving content based on semantic similarity or sequential order. It enhances agents with the
ability to recall relevant information during conversations by leveraging vector embeddings to
find similar content.

This implementation requires the RedisVL extra to be installed. Install with:

Expand Down Expand Up @@ -175,7 +177,19 @@ def __init__(self, config: RedisMemoryConfig | None = None) -> None:
self.config = config or RedisMemoryConfig()
client = Redis.from_url(url=self.config.redis_url) # type: ignore[reportUknownMemberType]

self.message_history = SemanticMessageHistory(name=self.config.index_name, redis_client=client)
if self.config.sequential:
self.message_history = MessageHistory(
name=self.config.index_name, prefix=self.config.prefix, redis_client=client
)
else:
vectorizer = HFTextVectorizer(model=self.config.model_name, dtype=self.config.datatype)
self.message_history = SemanticMessageHistory(
name=self.config.index_name,
prefix=self.config.prefix,
vectorizer=vectorizer,
distance_threshold=self.config.distance_threshold,
redis_client=client,
)

async def update_context(
self,
Expand Down Expand Up @@ -203,7 +217,7 @@ async def update_context(
else:
last_message = ""

query_results = await self.query(last_message)
query_results = await self.query(last_message, sequential=self.config.sequential)

stringified_messages = "\n\n".join([str(m.content) for m in query_results.results])

Expand All @@ -216,10 +230,10 @@ async def add(self, content: MemoryContent, cancellation_token: CancellationToke

.. note::

To perform semantic search over stored memories RedisMemory creates a vector embedding
from the content field of a MemoryContent object. This content is assumed to be text,
JSON, or Markdown, and is passed to the vector embedding model specified in
RedisMemoryConfig.
If RedisMemoryConfig is not set to 'sequential', to perform semantic search over stored
memories RedisMemory creates a vector embedding from the content field of a
MemoryContent object. This content is assumed to be text, JSON, or Markdown, and is
passed to the vector embedding model specified in RedisMemoryConfig.

Args:
content (MemoryContent): The memory content to store within Redis.
Expand All @@ -241,7 +255,7 @@ async def add(self, content: MemoryContent, cancellation_token: CancellationToke
metadata = {"mime_type": mime_type}
metadata.update(content.metadata if content.metadata else {})
self.message_history.add_message(
{"role": "user", "content": memory_content, "tool_call_id": serialize(metadata)} # type: ignore[reportArgumentType]
{"role": "user", "content": memory_content, "metadata": serialize(metadata)} # type: ignore[reportArgumentType]
)

async def query(
Expand All @@ -258,6 +272,7 @@ async def query(
top_k (int): The maximum number of relevant memories to include. Defaults to 10.
distance_threshold (float): The maximum distance in vector space to consider a memory
semantically similar when performing cosine similarity search. Defaults to 0.7.
sequential (bool): Ignore semantic similarity and return the top_k most recent memories.

Args:
query (str | MemoryContent): query to perform vector similarity search with. If a
Expand All @@ -270,34 +285,46 @@ async def query(
Returns:
memoryQueryResult: Object containing memories relevant to the provided query.
"""
# get the query string, or raise an error for unsupported MemoryContent types
if isinstance(query, str):
prompt = query
elif isinstance(query, MemoryContent):
if query.mime_type in (MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN):
prompt = str(query.content)
elif query.mime_type == MemoryMimeType.JSON:
prompt = serialize(query.content)
else:
raise NotImplementedError(
f"Error: {query.mime_type} is not supported. Only MemoryMimeType.TEXT, MemoryMimeType.JSON, MemoryMimeType.MARKDOWN are currently supported."
)
else:
raise TypeError("'query' must be either a string or MemoryContent")

top_k = kwargs.pop("top_k", self.config.top_k)
distance_threshold = kwargs.pop("distance_threshold", self.config.distance_threshold)

results = self.message_history.get_relevant(
prompt=prompt, # type: ignore[reportArgumentType]
top_k=top_k,
distance_threshold=distance_threshold,
raw=False,
)
# if sequential memory is requested skip prompt creation
sequential = bool(kwargs.pop("sequential", self.config.sequential))
if self.config.sequential and not sequential:
raise ValueError(
"Non-sequential queries cannot be run with an underlying sequential RedisMemory. Set sequential=False in RedisMemoryConfig to enable semantic memory querying."
)
elif sequential or self.config.sequential:
results = self.message_history.get_recent(
top_k=top_k,
raw=False,
)
else:
# get the query string, or raise an error for unsupported MemoryContent types
if isinstance(query, str):
prompt = query
elif isinstance(query, MemoryContent):
if query.mime_type in (MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN):
prompt = str(query.content)
elif query.mime_type == MemoryMimeType.JSON:
prompt = serialize(query.content)
else:
raise NotImplementedError(
f"Error: {query.mime_type} is not supported. Only MemoryMimeType.TEXT, MemoryMimeType.JSON, MemoryMimeType.MARKDOWN are currently supported."
)
else:
raise TypeError("'query' must be either a string or MemoryContent")

results = self.message_history.get_relevant( # type: ignore
prompt=prompt, # type: ignore[reportArgumentType]
top_k=top_k,
distance_threshold=distance_threshold,
raw=False,
)

memories: List[MemoryContent] = []
for result in results:
metadata = deserialize(result["tool_call_id"]) # type: ignore[reportArgumentType]
for result in results: # type: ignore[reportUnkownVariableType]
metadata = deserialize(result["metadata"]) # type: ignore[reportArgumentType]
mime_type = MemoryMimeType(metadata.pop("mime_type"))
if mime_type in (MemoryMimeType.TEXT, MemoryMimeType.MARKDOWN):
memory_content = result["content"] # type: ignore[reportArgumentType]
Expand Down
146 changes: 140 additions & 6 deletions python/packages/autogen-ext/tests/memory/test_redis_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async def test_redis_memory_query_with_mock() -> None:
memory = RedisMemory(config=config)

mock_history.get_relevant.return_value = [
{"content": "test content", "tool_call_id": '{"foo": "bar", "mime_type": "text/plain"}'}
{"content": "test content", "metadata": '{"foo": "bar", "mime_type": "text/plain"}'}
]
result = await memory.query("test")
assert len(result.results) == 1
Expand Down Expand Up @@ -86,13 +86,26 @@ def semantic_config() -> RedisMemoryConfig:
return RedisMemoryConfig(top_k=5, distance_threshold=0.5, model_name="sentence-transformers/all-mpnet-base-v2")


@pytest.fixture
def sequential_config() -> RedisMemoryConfig:
    """Create base configuration using sequential memory."""
    return RedisMemoryConfig(top_k=5, sequential=True)


@pytest_asyncio.fixture  # type: ignore[reportUntypedFunctionDecorator]
async def semantic_memory(semantic_config: RedisMemoryConfig) -> AsyncGenerator[RedisMemory]:
    """Yield a RedisMemory configured for semantic search, closing it at teardown."""
    mem = RedisMemory(semantic_config)
    yield mem
    await mem.close()


@pytest_asyncio.fixture  # type: ignore[reportUntypedFunctionDecorator]
async def sequential_memory(sequential_config: RedisMemoryConfig) -> AsyncGenerator[RedisMemory]:
    """Yield a RedisMemory configured for sequential (recency) recall, closing it at teardown."""
    mem = RedisMemory(sequential_config)
    yield mem
    await mem.close()


## UNIT TESTS ##
def test_memory_config() -> None:
default_config = RedisMemoryConfig()
Expand All @@ -104,6 +117,7 @@ def test_memory_config() -> None:
assert default_config.top_k == 10
assert default_config.distance_threshold == 0.7
assert default_config.model_name == "sentence-transformers/all-mpnet-base-v2"
assert not default_config.sequential

# test we can specify each of these values
url = "rediss://localhost:7010"
Expand Down Expand Up @@ -144,14 +158,36 @@ def test_memory_config() -> None:

@pytest.mark.asyncio
@pytest.mark.skipif(not redis_available(), reason="Redis instance not available locally")
async def test_create_semantic_memory() -> None:
config = RedisMemoryConfig(index_name="semantic_agent")
@pytest.mark.parametrize("sequential", [True, False])
async def test_create_memory(sequential: bool) -> None:
config = RedisMemoryConfig(index_name="semantic_agent", sequential=sequential)
memory = RedisMemory(config=config)

assert memory.message_history is not None
await memory.close()


@pytest.mark.asyncio
@pytest.mark.skipif(not redis_available(), reason="Redis instance not available locally")
async def test_specify_vectorizer() -> None:
    """Verify the configured embedding model is wired into the underlying vectorizer."""
    # each known model should yield a vectorizer with the expected embedding dims
    for model, expected_dims in (
        ("redis/langcache-embed-v1", 768),
        ("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2", 384),
    ):
        cfg = RedisMemoryConfig(index_name="semantic_agent", model_name=model)
        mem = RedisMemory(config=cfg)
        assert mem.message_history._vectorizer.dims == expected_dims  # type: ignore[reportPrivateUsage]
        await mem.close()

    # a non-existent model name should raise an error at construction time
    bad_cfg = RedisMemoryConfig(index_name="semantic_agent", model_name="not-a-real-model")
    with pytest.raises(OSError):
        _ = RedisMemory(config=bad_cfg)


@pytest.mark.asyncio
@pytest.mark.skipif(not redis_available(), reason="Redis instance not available locally")
async def test_update_context(semantic_memory: RedisMemory) -> None:
Expand Down Expand Up @@ -223,7 +259,7 @@ async def test_update_context(semantic_memory: RedisMemory) -> None:

@pytest.mark.asyncio
@pytest.mark.skipif(not redis_available(), reason="Redis instance not available locally")
async def test_add_and_query(semantic_memory: RedisMemory) -> None:
async def test_add_and_query_with_string(semantic_memory: RedisMemory) -> None:
content_1 = MemoryContent(
content="I enjoy fruits like apples, oranges, and bananas.", mime_type=MemoryMimeType.TEXT, metadata={}
)
Expand Down Expand Up @@ -251,6 +287,38 @@ async def test_add_and_query(semantic_memory: RedisMemory) -> None:
assert memories.results[1].metadata == {"description": "additional info"}


@pytest.mark.asyncio
@pytest.mark.skipif(not redis_available(), reason="Redis instance not available locally")
async def test_add_and_query_with_memory_content(semantic_memory: RedisMemory) -> None:
    """Add memories and query them using MemoryContent objects instead of plain strings."""
    fruit_query = MemoryContent(content="Fruits that I like.", mime_type=MemoryMimeType.TEXT)

    first = MemoryContent(
        content="I enjoy fruits like apples, oranges, and bananas.", mime_type=MemoryMimeType.TEXT, metadata={}
    )
    await semantic_memory.add(first)

    # a semantically similar query should surface the stored memory
    hits = await semantic_memory.query(fruit_query)
    assert len(hits.results) == 1

    # a dissimilar query should come back empty
    misses = await semantic_memory.query(
        MemoryContent(content="The king of England", mime_type=MemoryMimeType.TEXT)
    )
    assert len(misses.results) == 0

    # multiple relevant memories should all be returned, metadata intact
    second = MemoryContent(
        content="I also like mangos and pineapples.",
        mime_type=MemoryMimeType.TEXT,
        metadata={"description": "additional info"},
    )
    await semantic_memory.add(second)

    hits = await semantic_memory.query(fruit_query)
    assert len(hits.results) == 2
    assert hits.results[0].metadata == {}
    assert hits.results[1].metadata == {"description": "additional info"}


@pytest.mark.asyncio
@pytest.mark.skipif(not redis_available(), reason="Redis instance not available locally")
async def test_clear(semantic_memory: RedisMemory) -> None:
Expand Down Expand Up @@ -283,9 +351,16 @@ async def test_close(semantic_config: RedisMemoryConfig) -> None:
## INTEGRATION TESTS ##
@pytest.mark.asyncio
@pytest.mark.skipif(not redis_available(), reason="Redis instance not available locally")
async def test_basic_workflow(semantic_config: RedisMemoryConfig) -> None:
@pytest.mark.parametrize("config_type", ["sequential", "semantic"])
async def test_basic_workflow(config_type: str) -> None:
"""Test basic memory operations with semantic memory."""
memory = RedisMemory(config=semantic_config)
if config_type == "sequential":
config = RedisMemoryConfig(top_k=5, sequential=True)
else:
config = RedisMemoryConfig(
top_k=5, distance_threshold=0.5, model_name="sentence-transformers/all-mpnet-base-v2"
)
memory = RedisMemory(config=config)
await memory.clear()

await memory.add(
Expand Down Expand Up @@ -318,6 +393,11 @@ async def test_text_memory_type(semantic_memory: RedisMemory) -> None:
assert len(results.results) > 0
assert any("Simple text content" in str(r.content) for r in results.results)

# Query for text content with a MemoryContent object
results = await semantic_memory.query(MemoryContent(content="simple text content", mime_type=MemoryMimeType.TEXT))
assert len(results.results) > 0
assert any("Simple text content" in str(r.content) for r in results.results)


@pytest.mark.asyncio
@pytest.mark.skipif(not redis_available(), reason="Redis instance not available locally")
Expand Down Expand Up @@ -419,3 +499,57 @@ async def test_query_arguments(semantic_memory: RedisMemory) -> None:
# limit search to only close matches
results = await semantic_memory.query("my favorite fruit are what?", distance_threshold=0.2)
assert len(results.results) == 1

# get memories based on recency instead of relevance
results = await semantic_memory.query("fast sports cars", sequential=True)
assert len(results.results) == 3

# setting 'sequential' to False results in default behaviour
results = await semantic_memory.query("my favorite fruit are what?", sequential=False)
assert len(results.results) == 3


@pytest.mark.asyncio
@pytest.mark.skipif(not redis_available(), reason="Redis instance not available locally")
async def test_sequential_memory_workflow(sequential_memory: RedisMemory) -> None:
    """Exercise a sequential (recency-ordered) memory end to end."""
    await sequential_memory.clear()

    sentences = (
        "my favorite fruit are apples",
        "I read the encyclopedia britanica and my favorite section was on the Napoleonic Wars.",
        "Sharks have no idea that camels exist.",
        "Python is a popular programming language used for machine learning and AI applications.",
        "Fifth random and unrelated sentence",
    )
    for sentence in sentences:
        await sequential_memory.add(MemoryContent(content=sentence, mime_type=MemoryMimeType.TEXT))

    # default search returns the last 5 memories
    recent = await sequential_memory.query("what fruits do I like?")
    assert len(recent.results) == 5

    # top_k caps the number of returned results
    recent = await sequential_memory.query("what fruits do I like?", top_k=2)
    assert len(recent.results) == 2

    # sequential memory ignores semantic similarity entirely
    recent = await sequential_memory.query("How do I make peanut butter sandwiches?")
    assert len(recent.results) == 5

    # setting 'sequential' to True in the query method is redundant but harmless
    recent = await sequential_memory.query("fast sports cars", sequential=True)
    assert len(recent.results) == 5

    # requesting semantic search from a sequential memory raises an error
    with pytest.raises(ValueError):
        _ = await sequential_memory.query("my favorite fruit are what?", sequential=False)
Loading
Loading