Skip to content

Commit af2b03b

Browse files
Merge pull request #1086 from julep-ai/f/task-execution-workflow-tests
Add more workflow tests
2 parents 1ae58c9 + 0cc1614 commit af2b03b

File tree

2 files changed

+194
-12
lines changed

2 files changed

+194
-12
lines changed

agents-api/agents_api/workflows/task_execution/__init__.py

+3-11
Original file line numberDiff line numberDiff line change
@@ -331,19 +331,11 @@ async def _handle_PromptStep(
331331
step: PromptStep,
332332
):
333333
message = self.outcome.output
334-
if (
335-
step.unwrap
336-
or (not step.unwrap and not step.auto_run_tools)
337-
or (not step.unwrap and message["choices"][0]["finish_reason"] != "tool_calls")
338-
):
339-
workflow.logger.debug(f"Prompt step: Received response: {message}")
340-
return PartialTransition(output=message)
341-
342334
choice = message["choices"][0]
343335
finish_reason = choice["finish_reason"]
344-
345-
if not (step.auto_run_tools and not step.unwrap and finish_reason == "tool_calls"):
346-
return None
336+
if step.unwrap or not step.auto_run_tools or finish_reason != "tool_calls":
337+
workflow.logger.debug(f"Prompt step: Received response: {message}")
338+
return PartialTransition(output=message)
347339

348340
tool_calls_input = choice["message"]["tool_calls"]
349341
input_type = tool_calls_input[0]["type"]

agents-api/tests/test_task_execution_workflow.py

+191-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import uuid
22
from datetime import timedelta
3-
from unittest.mock import Mock, patch
3+
from unittest.mock import Mock, call, patch
44

55
from agents_api.activities import task_steps
66
from agents_api.activities.execute_api_call import execute_api_call
@@ -12,6 +12,8 @@
1212
BaseIntegrationDef,
1313
CaseThen,
1414
GetStep,
15+
PromptItem,
16+
PromptStep,
1517
SwitchStep,
1618
SystemDef,
1719
TaskSpecDef,
@@ -475,3 +477,191 @@ async def _():
475477
)
476478
is None
477479
)
480+
481+
482+
@test("task execution workflow: handle prompt step, unwrap is True")
483+
async def _():
484+
wf = TaskExecutionWorkflow()
485+
step = PromptStep(prompt="hi there", unwrap=True)
486+
execution_input = ExecutionInput(
487+
developer_id=uuid.uuid4(),
488+
agent=Agent(
489+
id=uuid.uuid4(),
490+
created_at=utcnow(),
491+
updated_at=utcnow(),
492+
name="agent1",
493+
),
494+
agent_tools=[],
495+
arguments={},
496+
task=TaskSpecDef(
497+
name="task1",
498+
tools=[],
499+
workflows=[Workflow(name="main", steps=[step])],
500+
),
501+
)
502+
context = StepContext(
503+
execution_input=execution_input,
504+
current_input="value 1",
505+
cursor=TransitionTarget(
506+
workflow="main",
507+
step=0,
508+
),
509+
)
510+
message = {"choices": [{"finish_reason": "stop"}]}
511+
outcome = StepOutcome(output=message)
512+
wf.context = context
513+
wf.outcome = outcome
514+
with patch("agents_api.workflows.task_execution.workflow") as workflow:
515+
workflow.logger = Mock()
516+
workflow.execute_activity.return_value = "activity"
517+
518+
assert await wf.handle_step(step=step) == PartialTransition(output=message)
519+
workflow.execute_activity.assert_not_called()
520+
521+
522+
@test("task execution workflow: handle prompt step, unwrap is False, autorun tools is False")
523+
async def _():
524+
wf = TaskExecutionWorkflow()
525+
step = PromptStep(prompt="hi there", unwrap=False, auto_run_tools=False)
526+
execution_input = ExecutionInput(
527+
developer_id=uuid.uuid4(),
528+
agent=Agent(
529+
id=uuid.uuid4(),
530+
created_at=utcnow(),
531+
updated_at=utcnow(),
532+
name="agent1",
533+
),
534+
agent_tools=[],
535+
arguments={},
536+
task=TaskSpecDef(
537+
name="task1",
538+
tools=[],
539+
workflows=[Workflow(name="main", steps=[step])],
540+
),
541+
)
542+
context = StepContext(
543+
execution_input=execution_input,
544+
current_input="value 1",
545+
cursor=TransitionTarget(
546+
workflow="main",
547+
step=0,
548+
),
549+
)
550+
message = {"choices": [{"finish_reason": "stop"}]}
551+
outcome = StepOutcome(output=message)
552+
wf.context = context
553+
wf.outcome = outcome
554+
with patch("agents_api.workflows.task_execution.workflow") as workflow:
555+
workflow.logger = Mock()
556+
workflow.execute_activity.return_value = "activity"
557+
558+
assert await wf.handle_step(step=step) == PartialTransition(output=message)
559+
workflow.execute_activity.assert_not_called()
560+
561+
562+
@test(
563+
"task execution workflow: handle prompt step, unwrap is False, finish reason is not tool_calls"
564+
)
565+
async def _():
566+
wf = TaskExecutionWorkflow()
567+
step = PromptStep(prompt="hi there", unwrap=False)
568+
execution_input = ExecutionInput(
569+
developer_id=uuid.uuid4(),
570+
agent=Agent(
571+
id=uuid.uuid4(),
572+
created_at=utcnow(),
573+
updated_at=utcnow(),
574+
name="agent1",
575+
),
576+
agent_tools=[],
577+
arguments={},
578+
task=TaskSpecDef(
579+
name="task1",
580+
tools=[],
581+
workflows=[Workflow(name="main", steps=[step])],
582+
),
583+
)
584+
context = StepContext(
585+
execution_input=execution_input,
586+
current_input="value 1",
587+
cursor=TransitionTarget(
588+
workflow="main",
589+
step=0,
590+
),
591+
)
592+
message = {"choices": [{"finish_reason": "stop"}]}
593+
outcome = StepOutcome(output=message)
594+
wf.context = context
595+
wf.outcome = outcome
596+
with patch("agents_api.workflows.task_execution.workflow") as workflow:
597+
workflow.logger = Mock()
598+
workflow.execute_activity.return_value = "activity"
599+
600+
assert await wf.handle_step(step=step) == PartialTransition(output=message)
601+
workflow.execute_activity.assert_not_called()
602+
603+
604+
@test("task execution workflow: handle prompt step, function call")
605+
async def _():
606+
async def _resp():
607+
return StepOutcome(output="function_call")
608+
609+
wf = TaskExecutionWorkflow()
610+
step = PromptStep(prompt=[PromptItem(content="hi there", role="user")])
611+
execution_input = ExecutionInput(
612+
developer_id=uuid.uuid4(),
613+
agent=Agent(
614+
id=uuid.uuid4(),
615+
created_at=utcnow(),
616+
updated_at=utcnow(),
617+
name="agent1",
618+
),
619+
agent_tools=[],
620+
arguments={},
621+
task=TaskSpecDef(
622+
name="task1",
623+
tools=[],
624+
workflows=[Workflow(name="main", steps=[step])],
625+
),
626+
)
627+
context = StepContext(
628+
execution_input=execution_input,
629+
current_input="value 1",
630+
cursor=TransitionTarget(
631+
workflow="main",
632+
step=0,
633+
),
634+
)
635+
message = {
636+
"choices": [
637+
{"finish_reason": "tool_calls", "message": {"tool_calls": [{"type": "function"}]}}
638+
]
639+
}
640+
outcome = StepOutcome(output=message)
641+
wf.context = context
642+
wf.outcome = outcome
643+
with patch("agents_api.workflows.task_execution.workflow") as workflow:
644+
workflow.logger = Mock()
645+
workflow.execute_activity.side_effect = [_resp(), _resp()]
646+
647+
assert await wf.handle_step(step=step) == PartialTransition(
648+
output="function_call", type="resume"
649+
)
650+
workflow.execute_activity.assert_has_calls([
651+
call(
652+
task_steps.raise_complete_async,
653+
args=[context, [{"type": "function"}]],
654+
schedule_to_close_timeout=timedelta(days=31),
655+
retry_policy=DEFAULT_RETRY_POLICY,
656+
heartbeat_timeout=timedelta(seconds=temporal_heartbeat_timeout),
657+
),
658+
call(
659+
task_steps.prompt_step,
660+
context,
661+
schedule_to_close_timeout=timedelta(
662+
seconds=30 if debug or testing else temporal_schedule_to_close_timeout
663+
),
664+
retry_policy=DEFAULT_RETRY_POLICY,
665+
heartbeat_timeout=timedelta(seconds=temporal_heartbeat_timeout),
666+
),
667+
])

0 commit comments

Comments
 (0)