Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Agent inputs are only set by default if not already set #91

Merged
merged 1 commit into from
Jul 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion ix/agents/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,16 @@ async def chat_with_ai(

logger.info(f"Sending request to chain={self.chain.name} prompt={user_input}")

# auto-map user_input to input if not provided.
# work around until chat input key can be configured per chain
extra_kwargs = {}
if "input" not in user_input:
extra_kwargs["input"] = user_input

start = time.time()
try:
# Hax: copy user_input to input to support agents.
response = await chain.arun(input=user_input["user_input"], **user_input)
response = await chain.arun(**extra_kwargs, **user_input)
except: # noqa: E722
raise
finally:
Expand Down
28 changes: 28 additions & 0 deletions ix/agents/tests/test_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,34 @@ async def test_start_task(self, mock_openai, mock_embeddings):
# "total_tokens": 12,
# }

async def test_start_task_with_input(self, mock_openai, mock_embeddings):
"""
Test that if `input` is included in inputs then it will be
used instead of the default `user_input -> input` mapping.
"""
await sync_to_async(load_fixture)("node_types")
task = await sync_to_async(fake_task)()
mock_reply = await sync_to_async(fake_command_reply)(task=task)
await mock_reply.adelete()
query = TaskLogMessage.objects.filter(task=task)
count = await query.acount()
assert count == 0
agent_process = AgentProcess(task=task, agent=task.agent, chain=task.chain)

inputs = {"user_input": "hello agent 1", "input": "existing input"}
return_value = await agent_process.start(inputs)
assert return_value is True

count = await query.acount()
assert count == 2
messages = [msg async for msg in query]
think_msg = messages[0]
thought_msg = messages[1]
assert think_msg.content["type"] == "THINK"
assert think_msg.content["input"] == inputs
assert thought_msg.content["type"] == "THOUGHT"
assert isinstance(thought_msg.content["runtime"], float)


def msg_to_response(msg: TaskLogMessage):
"""utility for turning model instances back into response json"""
Expand Down