Skip to content

Commit 9d3633a

Browse files
authored
Merge pull request #1076 from julep-ai/x/evaluate-step
fix(agents-api): allow nested dictionaries in ``EvaluateStep`` and ``SetStep``
2 parents faa767d + 48a63c8 commit 9d3633a

File tree

13 files changed

+105
-35
lines changed

13 files changed

+105
-35
lines changed

agents-api/agents_api/activities/utils.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,9 @@ def get_evaluator(
422422

423423

424424
@beartype
425-
def simple_eval_dict(exprs: dict[str, str], values: dict[str, Any]) -> dict[str, Any]:
425+
def simple_eval_dict(
426+
exprs: dict[str, str | dict[str, Any]], values: dict[str, Any]
427+
) -> dict[str, Any]:
426428
if len(exprs) > MAX_COLLECTION_SIZE:
427429
msg = f"Too many expressions (max {MAX_COLLECTION_SIZE})"
428430
raise ValueError(msg)
@@ -433,7 +435,10 @@ def simple_eval_dict(exprs: dict[str, str], values: dict[str, Any]) -> dict[str,
433435
raise ValueError(msg)
434436

435437
evaluator = get_evaluator(names=values)
436-
return {k: evaluator.eval(v) for k, v in exprs.items()}
438+
return {
439+
k: evaluator.eval(v) if isinstance(v, str) else simple_eval_dict(v, values)
440+
for k, v in exprs.items()
441+
}
437442

438443

439444
def get_handler_with_filtered_params(system: SystemDef) -> Callable:

agents-api/agents_api/autogen/Tasks.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ class EvaluateStep(BaseModel):
247247
"""
248248
The label of this step for referencing it from other steps
249249
"""
250-
evaluate: dict[str, list[str] | dict[str, str] | list[dict[str, str]] | str]
250+
evaluate: dict[str, dict[str, Any] | str]
251251
"""
252252
The expression to evaluate
253253
"""
@@ -861,7 +861,7 @@ class SetStep(BaseModel):
861861
"""
862862
The label of this step for referencing it from other steps
863863
"""
864-
set: dict[str, str]
864+
set: dict[str, dict[str, Any] | str]
865865
"""
866866
The value to set
867867
"""

agents-api/agents_api/clients/litellm.py

+8-9
Original file line numberDiff line numberDiff line change
@@ -121,15 +121,14 @@ async def get_model_list(*, custom_api_key: str | None = None) -> list[dict]:
121121
list[dict]: A list of model information dictionaries
122122
"""
123123

124-
headers = {
125-
"accept": "application/json",
126-
"x-api-key": custom_api_key or litellm_master_key
127-
}
128-
129-
async with aiohttp.ClientSession() as session, session.get(
130-
url=f"{litellm_url}/models" if not custom_api_key else "/models",
131-
headers=headers
132-
) as response:
124+
headers = {"accept": "application/json", "x-api-key": custom_api_key or litellm_master_key}
125+
126+
async with (
127+
aiohttp.ClientSession() as session,
128+
session.get(
129+
url=f"{litellm_url}/models" if not custom_api_key else "/models", headers=headers
130+
) as response,
131+
):
133132
response.raise_for_status()
134133
data = await response.json()
135134
return data["data"]

agents-api/agents_api/routers/agents/create_agent.py

-1
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ async def create_agent(
1919
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
2020
data: CreateAgentRequest,
2121
) -> ResourceCreatedResponse:
22-
2322
if data.model:
2423
await validate_model(data.model)
2524

agents-api/agents_api/routers/agents/create_or_update_agent.py

-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ async def create_or_update_agent(
2222
data: CreateOrUpdateAgentRequest,
2323
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
2424
) -> ResourceCreatedResponse:
25-
2625
if data.model:
2726
await validate_model(data.model)
2827

agents-api/agents_api/routers/agents/patch_agent.py

-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ async def patch_agent(
2222
agent_id: UUID,
2323
data: PatchAgentRequest,
2424
) -> ResourceUpdatedResponse:
25-
2625
if data.model:
2726
await validate_model(data.model)
2827

agents-api/agents_api/routers/agents/update_agent.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@ async def update_agent(
2121
x_developer_id: Annotated[UUID, Depends(get_developer_id)],
2222
agent_id: UUID,
2323
data: UpdateAgentRequest,
24-
) -> ResourceUpdatedResponse:
25-
24+
) -> ResourceUpdatedResponse:
2625
if data.model:
2726
await validate_model(data.model)
2827

agents-api/agents_api/routers/utils/model_validation.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,5 @@ async def validate_model(model_name: str) -> None:
1515
if model_name not in available_models:
1616
raise HTTPException(
1717
status_code=HTTP_400_BAD_REQUEST,
18-
detail=f"Model {model_name} not available. Available models: {available_models}"
18+
detail=f"Model {model_name} not available. Available models: {available_models}",
1919
)

agents-api/tests/fixtures.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -450,9 +450,14 @@ async def test_tool(
450450

451451
@fixture(scope="global")
452452
def client(_dsn=pg_dsn):
453-
with TestClient(app=app) as client:
454-
with patch("agents_api.routers.utils.model_validation.get_model_list", return_value=SAMPLE_MODELS):
455-
yield client
453+
with (
454+
TestClient(app=app) as client,
455+
patch(
456+
"agents_api.routers.utils.model_validation.get_model_list",
457+
return_value=SAMPLE_MODELS,
458+
),
459+
):
460+
yield client
456461

457462

458463
@fixture(scope="global")
+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from ward import test, raises
2+
from agents_api.activities.utils import simple_eval_dict, MAX_STRING_LENGTH, MAX_COLLECTION_SIZE
3+
from simpleeval import NameNotDefined
4+
5+
@test("utility: simple_eval_dict - string length overflow")
6+
async def _():
7+
with raises(ValueError):
8+
simple_eval_dict({"a": "b" * (MAX_STRING_LENGTH + 1)}, {})
9+
10+
@test("utility: simple_eval_dict - collection size overflow")
11+
async def _():
12+
with raises(ValueError):
13+
simple_eval_dict({str(i): "b" for i in range(MAX_COLLECTION_SIZE + 1)}, {})
14+
15+
@test("utility: simple_eval_dict - value undefined")
16+
async def _():
17+
with raises(NameNotDefined):
18+
simple_eval_dict({"a": "b"}, {})
19+
20+
@test("utility: simple_eval_dict")
21+
async def _():
22+
exprs = {"a": {"b": "x + 5", "c": "x + 6"}}
23+
values = {"x": 5}
24+
result = simple_eval_dict(exprs, values)
25+
assert result == {"a": {"b": 10, "c": 11}}
26+

agents-api/tests/test_workflow_routes.py

+45
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,51 @@ async def _(
8080
).raise_for_status()
8181

8282

83+
@test("workflow route: evaluate step single with yaml - nested")
84+
async def _(
85+
make_request=make_request,
86+
agent=test_agent,
87+
):
88+
agent_id = str(agent.id)
89+
90+
async with patch_testing_temporal():
91+
task_data = """
92+
name: test task
93+
description: test task about
94+
input_schema:
95+
type: object
96+
additionalProperties: true
97+
98+
main:
99+
- evaluate:
100+
hello: '"world"'
101+
hello2:
102+
hello3:
103+
hello4: inputs[0]['test']
104+
"""
105+
106+
result = (
107+
make_request(
108+
method="POST",
109+
url=f"/agents/{agent_id}/tasks",
110+
content=task_data.encode("utf-8"),
111+
headers={"Content-Type": "text/yaml"},
112+
)
113+
.raise_for_status()
114+
.json()
115+
)
116+
117+
task_id = result["id"]
118+
119+
execution_data = {"input": {"test": "input"}}
120+
121+
make_request(
122+
method="POST",
123+
url=f"/tasks/{task_id}/executions",
124+
json=execution_data,
125+
).raise_for_status()
126+
127+
83128
@test("workflow route: create or update: evaluate step single with yaml")
84129
async def _(
85130
make_request=make_request,

typespec/tasks/steps.tsp

+2-2
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ model EvaluateStep extends BaseWorkflowStep<"evaluate"> {
134134

135135
model EvaluateStepDef {
136136
/** The expression to evaluate */
137-
evaluate: ExpressionObject<unknown>;
137+
evaluate: Record<TypedExpression<unknown> | Record<unknown>>;
138138
}
139139

140140
model WaitForInputStep extends BaseWorkflowStep<"wait_for_input"> {
@@ -191,7 +191,7 @@ model SetStep extends BaseWorkflowStep<"set"> {
191191

192192
model SetStepDef {
193193
/** The value to set */
194-
set: Record<TypedExpression<unknown>>;
194+
set: Record<TypedExpression<unknown> | Record<unknown>>;
195195
}
196196

197197
///////////////////////

typespec/tsp-output/@typespec/openapi3/openapi-1.0.0.yaml

+5-11
Original file line numberDiff line numberDiff line change
@@ -4887,17 +4887,8 @@ components:
48874887
additionalProperties:
48884888
anyOf:
48894889
- $ref: '#/components/schemas/Common.PyExpression'
4890-
- type: array
4891-
items:
4892-
$ref: '#/components/schemas/Common.PyExpression'
48934890
- type: object
4894-
additionalProperties:
4895-
$ref: '#/components/schemas/Common.PyExpression'
4896-
- type: array
4897-
items:
4898-
type: object
4899-
additionalProperties:
4900-
$ref: '#/components/schemas/Common.PyExpression'
4891+
additionalProperties: {}
49014892
description: The expression to evaluate
49024893
allOf:
49034894
- type: object
@@ -5919,7 +5910,10 @@ components:
59195910
set:
59205911
type: object
59215912
additionalProperties:
5922-
$ref: '#/components/schemas/Common.PyExpression'
5913+
anyOf:
5914+
- $ref: '#/components/schemas/Common.PyExpression'
5915+
- type: object
5916+
additionalProperties: {}
59235917
description: The value to set
59245918
allOf:
59255919
- type: object

0 commit comments

Comments
 (0)