Skip to content

Commit

Permalink
Merge pull request #146 from kreneskyp/agent_memory
Browse files Browse the repository at this point in the history
Agent memory
  • Loading branch information
kreneskyp authored Sep 25, 2023
2 parents 705ff66 + 45deea3 commit b961fed
Show file tree
Hide file tree
Showing 19 changed files with 181 additions and 101 deletions.
10 changes: 7 additions & 3 deletions ix/agents/tests/test_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ async def test_start_task(self, mock_openai, anode_types):
thought_msg = messages[1]
assert think_msg.content["type"] == "THINK"
assert think_msg.content["input"] == {
"user_input": "hello agent 1",
"input": "hello agent 1",
"user_input": "hello agent 1",
"question": "hello agent 1",
}
assert thought_msg.content["type"] == "THOUGHT"
Expand Down Expand Up @@ -96,7 +96,7 @@ async def test_start_task_with_input(self, mock_openai, anode_types):

inputs = {
"user_input": "hello agent 1",
"input": "existing input",
"input": "hello agent 1",
"question": "hello agent 1",
}
return_value = await agent_process.start(inputs)
Expand All @@ -108,7 +108,11 @@ async def test_start_task_with_input(self, mock_openai, anode_types):
think_msg = messages[0]
thought_msg = messages[1]
assert think_msg.content["type"] == "THINK"
assert think_msg.content["input"] == inputs
assert think_msg.content["input"] == {
"input": "hello agent 1",
"user_input": "hello agent 1",
"question": "hello agent 1",
}
assert thought_msg.content["type"] == "THOUGHT"
assert isinstance(thought_msg.content["runtime"], float)

Expand Down
35 changes: 34 additions & 1 deletion ix/chains/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import functools
import json
import logging
import time
import traceback
Expand Down Expand Up @@ -25,6 +27,20 @@
logger = logging.getLogger(__name__)


def log_error(func):
"""simple logging decorator to expose errors in callback methods"""

@functools.wraps(func)
async def wrapper(*args, **kwargs):
try:
return await func(*args, **kwargs)
except Exception as e:
logger.error(traceback.format_exc())
raise e

return wrapper


@dataclasses.dataclass
class RunContext:
"""Context info for a single run of an llm/chain."""
Expand Down Expand Up @@ -117,6 +133,7 @@ def chat_id(self) -> str:
return None
return chat.id

@log_error
async def on_chat_model_start(
self,
serialized: Dict[str, Any],
Expand Down Expand Up @@ -146,6 +163,7 @@ async def on_llm_start(
"""Runs when an LLM model starts"""
pass

@log_error
async def on_llm_new_token(
self, token: str, parent_run_id: Optional[UUID] = None, **kwargs: Any
) -> Any:
Expand All @@ -166,6 +184,7 @@ async def on_llm_error(
) -> Any:
"""Run when LLM errors."""

@log_error
async def on_chain_start(
self,
serialized: Dict[str, Any],
Expand All @@ -180,14 +199,28 @@ async def on_chain_start(
"""Run when chain starts running."""

if not self.parent_think_msg:
# Inputs will be encoded as JSON, but if they can't be, we'll just
# serialize them as a string.
try:
json.dumps(inputs)
except TypeError:
serialized_inputs = str(inputs)
else:
serialized_inputs = inputs

self.start = time.time()
think_msg = await TaskLogMessage.objects.acreate(
task_id=self.task.id,
role="SYSTEM",
content={"type": "THINK", "input": inputs, "agent": self.agent.alias},
content={
"type": "THINK",
"input": serialized_inputs,
"agent": self.agent.alias,
},
)
self.parent_think_msg = think_msg

@log_error
async def on_chain_end(
self,
outputs: Dict[str, Any],
Expand Down
24 changes: 13 additions & 11 deletions ix/chains/fixture_src/agents.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from ix.chains.fixture_src.common import VERBOSE
from ix.chains.fixture_src.targets import (
PROMPT_TARGET,
TOOLS_TARGET,
LLM_TARGET,
MEMORY_TARGET,
Expand All @@ -27,12 +26,15 @@
]


OPENAI_FUNCTIONS_AGENT_CLASS_PATH = (
"ix.chains.loaders.agents.initialize_openai_functions"
)
OPENAI_FUNCTIONS_AGENT = {
"class_path": "ix.chains.loaders.agents.initialize_openai_functions",
"class_path": OPENAI_FUNCTIONS_AGENT_CLASS_PATH,
"type": "agent",
"name": "OpenAI Function Agent",
"description": "Agent that uses OpenAI's API to generate text.",
"connectors": [LLM_TARGET, TOOLS_TARGET, PROMPT_TARGET, MEMORY_TARGET],
"connectors": [LLM_TARGET, TOOLS_TARGET, MEMORY_TARGET],
"fields": EXECUTOR_BASE_FIELDS,
}

Expand All @@ -41,7 +43,7 @@
"type": "agent",
"name": "OpenAI Multifunction Agent",
"description": "Agent that uses OpenAI's API to generate text.",
"connectors": [LLM_TARGET, TOOLS_TARGET, PROMPT_TARGET, MEMORY_TARGET],
"connectors": [LLM_TARGET, TOOLS_TARGET, MEMORY_TARGET],
"fields": EXECUTOR_BASE_FIELDS,
}

Expand All @@ -50,7 +52,7 @@
"type": "agent",
"name": "Zero Shot React Description Agent",
"description": "Agent that generates descriptions by taking zero-shot approach using reaction information.",
"connectors": [LLM_TARGET, TOOLS_TARGET, PROMPT_TARGET, MEMORY_TARGET],
"connectors": [LLM_TARGET, TOOLS_TARGET, MEMORY_TARGET],
"fields": EXECUTOR_BASE_FIELDS,
}

Expand All @@ -59,7 +61,7 @@
"type": "agent",
"name": "React Docstore Agent",
"description": "Agent that interacts with the document store to obtain reaction-based information.",
"connectors": [LLM_TARGET, TOOLS_TARGET, PROMPT_TARGET, MEMORY_TARGET],
"connectors": [LLM_TARGET, TOOLS_TARGET],
"fields": EXECUTOR_BASE_FIELDS,
}

Expand All @@ -68,7 +70,7 @@
"type": "agent",
"name": "Self Ask with Search Agent",
"description": "Agent that asks itself queries and searches for answers in a given context.",
"connectors": [LLM_TARGET, TOOLS_TARGET, PROMPT_TARGET, MEMORY_TARGET],
"connectors": [LLM_TARGET, TOOLS_TARGET],
"fields": EXECUTOR_BASE_FIELDS,
}

Expand All @@ -77,7 +79,7 @@
"type": "agent",
"name": "Conversational React Description Agent",
"description": "Agent that provides descriptions in a conversational manner using reaction information.",
"connectors": [LLM_TARGET, TOOLS_TARGET, PROMPT_TARGET, MEMORY_TARGET],
"connectors": [LLM_TARGET, TOOLS_TARGET],
"fields": EXECUTOR_BASE_FIELDS,
}

Expand All @@ -86,7 +88,7 @@
"type": "agent",
"name": "Chat Zero Shot React Description Agent",
"description": "Agent that generates descriptions in a chat-based context using a zero-shot approach and reaction information.",
"connectors": [LLM_TARGET, TOOLS_TARGET, PROMPT_TARGET, MEMORY_TARGET],
"connectors": [LLM_TARGET, TOOLS_TARGET],
"fields": EXECUTOR_BASE_FIELDS,
}

Expand All @@ -95,7 +97,7 @@
"type": "agent",
"name": "Chat Conversational React Description Agent",
"description": "Agent that provides descriptions in a chat-based context in a conversational manner using reaction information.",
"connectors": [LLM_TARGET, TOOLS_TARGET, PROMPT_TARGET, MEMORY_TARGET],
"connectors": [LLM_TARGET, TOOLS_TARGET],
"fields": EXECUTOR_BASE_FIELDS,
}

Expand All @@ -104,7 +106,7 @@
"type": "agent",
"name": "Structured Chat Zero Shot React Description Agent",
"description": "Agent that generates descriptions in a structured chat context using a zero-shot approach and reaction information.",
"connectors": [LLM_TARGET, TOOLS_TARGET, PROMPT_TARGET, MEMORY_TARGET],
"connectors": [LLM_TARGET, TOOLS_TARGET],
"fields": EXECUTOR_BASE_FIELDS,
}

Expand Down
5 changes: 5 additions & 0 deletions ix/chains/fixture_src/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@
"type": "string",
"default": "input",
},
{
"name": "return_messages",
"type": "boolean",
"default": False,
},
]

CONVERSATION_BUFFER_MEMORY = {
Expand Down
22 changes: 21 additions & 1 deletion ix/chains/loaders/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import sys
from typing import Callable

from langchain.prompts import MessagesPlaceholder

from langchain.agents import AgentType, AgentExecutor
from langchain.agents import initialize_agent as initialize_agent_base
from langchain.agents.agent_toolkits.base import BaseToolkit
Expand All @@ -19,8 +21,26 @@ def initialize_agent(agent: AgentType, **kwargs) -> Chain:
- unpacks agent_kwargs: allows agent_kwargs to be flattened into the ChainNode config
A flattened config simplifies the UX integration such that it works with TypeAutoFields
"""
# Inject placeholders into prompt for memory if provided
placeholders = []
if memories := kwargs.get("memory", None):
if not isinstance(memories, list):
memories = [memories]
placeholders = []
for component in memories:
if not getattr(component, "return_messages", False):
raise ValueError(
f"Memory component {component} has return_messages=False. Agents require "
f"return_messages=True."
)
for memory_key in component.memory_variables:
placeholders.append(MessagesPlaceholder(variable_name=memory_key))

# Re-pack agent_kwargs__* arguments into agent_kwargs
agent_kwargs = {}
agent_kwargs = {
"extra_prompt_messages": placeholders,
}

for key, value in kwargs.items():
if key.startswith("agent_kwargs__"):
agent_kwargs[key[15:]] = value
Expand Down
57 changes: 57 additions & 0 deletions ix/chains/tests/test_config_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from langchain.text_splitter import TextSplitter
from langchain.vectorstores import Redis

from ix.chains.fixture_src.agents import OPENAI_FUNCTIONS_AGENT_CLASS_PATH
from ix.chains.fixture_src.chains import CONVERSATIONAL_RETRIEVAL_CHAIN_CLASS_PATH
from ix.chains.fixture_src.document_loaders import GENERIC_LOADER_CLASS_PATH
from ix.chains.fixture_src.embeddings import OPENAI_EMBEDDINGS_CLASS_PATH
Expand Down Expand Up @@ -83,6 +84,16 @@ class TestLoadLLM:
},
}

AGENT_MEMORY = {
"class_path": "langchain.memory.ConversationBufferMemory",
"config": {
"input_key": "user_input",
"memory_key": "chat_history",
# agent requires return_messages=True
"return_messages": True,
},
}

MEMORY_WITH_SCOPE = {
"class_path": "ix.memory.artifacts.ArtifactMemory",
"config": {
Expand Down Expand Up @@ -507,6 +518,52 @@ async def test_load_agents(self, aload_chain, mock_openai, mock_google_api_key):
instance = await aload_chain(config)
assert isinstance(instance, AgentExecutor)

async def test_agent_memory(self, mock_openai, aload_chain, mock_google_api_key):
config = {
"class_path": OPENAI_FUNCTIONS_AGENT_CLASS_PATH,
"name": "tester",
"description": "test",
"config": {
"tools": [GOOGLE_SEARCH_CONFIG],
"llm": OPENAI_LLM,
"memory": AGENT_MEMORY,
},
}
executor = await aload_chain(config)
assert isinstance(executor, AgentExecutor) # sanity check

# 1. test that prompt includes placeholders
# 2. test that memory keys are correct
# 3. test that memory is loaded for agent
result = await executor.acall(inputs={"input": "foo", "user_input": "bar"})

# verify response contains memory
assert result["chat_history"][0].content == "bar"
assert result["chat_history"][1].content == "mock llm response"

# call second time to smoke test
await executor.acall(inputs={"input": "foo", "user_input": "bar"})

async def test_agent_memory_misconfigured(
self, mock_openai, aload_chain, mock_google_api_key
):
"""test agent/memory misconfigurations that should raise errors
- memory class must have `return_messages=True`
"""
config = {
"class_path": "ix.chains.loaders.agents.initialize_zero_shot_react_description",
"name": "tester",
"description": "test",
"config": {
"tools": [GOOGLE_SEARCH_CONFIG],
"llm": OPENAI_LLM,
"memory": MEMORY,
},
}
with pytest.raises(ValueError) as excinfo:
await aload_chain(config)
assert "Agents require return_messages=True" in str(excinfo.value)


TEST_DATA = Path("/var/app/test_data")
TEST_DOCUMENTS = TEST_DATA / "documents"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,6 @@
"toolkit"
],
"type": "target"
},
{
"key": "prompt",
"source_type": "prompt",
"type": "target"
},
{
"key": "memory",
"multiple": true,
"source_type": "memory",
"type": "target"
}
],
"description": "Agent that provides descriptions in a chat-based context in a conversational manner using reaction information.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,6 @@
"toolkit"
],
"type": "target"
},
{
"key": "prompt",
"source_type": "prompt",
"type": "target"
},
{
"key": "memory",
"multiple": true,
"source_type": "memory",
"type": "target"
}
],
"description": "Agent that generates descriptions in a chat-based context using a zero-shot approach and reaction information.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,6 @@
"toolkit"
],
"type": "target"
},
{
"key": "prompt",
"source_type": "prompt",
"type": "target"
},
{
"key": "memory",
"multiple": true,
"source_type": "memory",
"type": "target"
}
],
"description": "Agent that provides descriptions in a conversational manner using reaction information.",
Expand Down
Loading

0 comments on commit b961fed

Please sign in to comment.