Skip to content

Commit 6f1521c

Browse files
committed
chore(agents-api): misc validations
1 parent 5608d8e commit 6f1521c

File tree

9 files changed

+177
-16
lines changed

9 files changed

+177
-16
lines changed

Diff for: agents-api/agents_api/activities/container.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,19 @@
1+
from typing import Any
2+
3+
from aiobotocore.client import AioBaseClient
4+
from asyncpg.pool import Pool
5+
6+
17
class State:
2-
pass
8+
postgres_pool: Pool | None
9+
s3_client: AioBaseClient | None
10+
11+
def __init__(self):
12+
self.postgres_pool = None
13+
self.s3_client = None
14+
15+
def __setattr__(self, name: str, value: Any) -> None:
16+
super().__setattr__(name, value)
317

418

519
class Container:

Diff for: agents-api/agents_api/activities/execute_integration.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from ..common.exceptions.tools import IntegrationExecutionException
1010
from ..common.protocol.tasks import ExecutionInput, StepContext
1111
from ..env import testing
12-
from ..queries.tools import get_tool_args_from_metadata
12+
from ..queries import tools
1313
from .container import container
1414

1515

@@ -28,17 +28,22 @@ async def execute_integration(
2828

2929
developer_id = context.execution_input.developer_id
3030
agent_id = context.execution_input.agent.id
31+
32+
if context.execution_input.task is None:
33+
msg = "Task cannot be None in execution_input"
34+
raise ValueError(msg)
35+
3136
task_id = context.execution_input.task.id
3237

33-
merged_tool_args = await get_tool_args_from_metadata(
38+
merged_tool_args = await tools.get_tool_args_from_metadata(
3439
developer_id=developer_id,
3540
agent_id=agent_id,
3641
task_id=task_id,
3742
arg_type="args",
3843
connection_pool=container.state.postgres_pool,
3944
)
4045

41-
merged_tool_setup = await get_tool_args_from_metadata(
46+
merged_tool_setup = await tools.get_tool_args_from_metadata(
4247
developer_id=developer_id,
4348
agent_id=agent_id,
4449
task_id=task_id,

Diff for: agents-api/agents_api/activities/execute_system.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
)
2424
from ..common.protocol.tasks import ExecutionInput, StepContext
2525
from ..env import testing
26-
from ..queries.developers import get_developer
26+
from ..queries import developers
2727
from .container import container
2828
from .utils import get_handler
2929

@@ -95,7 +95,7 @@ async def execute_system(
9595

9696
# Handle chat operations
9797
if system.operation == "chat" and system.resource == "session":
98-
developer = await get_developer(
98+
developer = await developers.get_developer(
9999
developer_id=arguments["developer_id"],
100100
connection_pool=container.state.postgres_pool,
101101
)

Diff for: agents-api/agents_api/activities/task_steps/transition_step.py

+4
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ async def transition_step(
3939
msg = "Expected ExecutionInput type for context.execution_input"
4040
raise TypeError(msg)
4141

42+
if not context.execution_input.execution:
43+
msg = "Execution is required in execution_input"
44+
raise ValueError(msg)
45+
4246
# Create transition
4347
try:
4448
transition = await create_execution_transition(

Diff for: agents-api/agents_api/activities/task_steps/yield_step.py

+10
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,20 @@ async def yield_step(context: StepContext) -> StepOutcome:
1616
msg = "Expected ExecutionInput type for context.execution_input"
1717
raise TypeError(msg)
1818

19+
# Add validation for task
20+
if not context.execution_input.task:
21+
msg = "Task is required in execution_input"
22+
raise ValueError(msg)
23+
1924
all_workflows = context.execution_input.task.workflows
2025
workflow = context.current_step.workflow
2126
exprs = context.current_step.arguments
2227

28+
# Validate workflows exists
29+
if not all_workflows:
30+
msg = "No workflows found in task"
31+
raise ValueError(msg)
32+
2333
assert workflow in [wf.name for wf in all_workflows], (
2434
f"Workflow {workflow} not found in task"
2535
)

Diff for: agents-api/agents_api/activities/utils.py

+98-4
Original file line numberDiff line numberDiff line change
@@ -41,21 +41,60 @@ def safe_range(*args):
4141
return result
4242

4343

44-
def safe_json_loads(s: str):
44+
@beartype
45+
def safe_json_loads(s: str) -> Any:
46+
"""
47+
Safely load a JSON string with size limits.
48+
49+
Args:
50+
s: JSON string to parse
51+
52+
Returns:
53+
Parsed JSON data
54+
55+
Raises:
56+
ValueError: If string exceeds size limit
57+
"""
4558
if len(s) > MAX_STRING_LENGTH:
4659
msg = f"String exceeds maximum length of {MAX_STRING_LENGTH}"
4760
raise ValueError(msg)
4861
return json.loads(s)
4962

5063

51-
def safe_yaml_load(s: str):
64+
@beartype
65+
def safe_yaml_load(s: str) -> Any:
66+
"""
67+
Safely load a YAML string with size limits.
68+
69+
Args:
70+
s: YAML string to parse
71+
72+
Returns:
73+
Parsed YAML data
74+
75+
Raises:
76+
ValueError: If string exceeds size limit
77+
"""
5278
if len(s) > MAX_STRING_LENGTH:
5379
msg = f"String exceeds maximum length of {MAX_STRING_LENGTH}"
5480
raise ValueError(msg)
5581
return yaml.load(s)
5682

5783

84+
@beartype
5885
def safe_base64_decode(s: str) -> str:
86+
"""
87+
Safely decode a base64 string with size limits.
88+
89+
Args:
90+
s: Base64 string to decode
91+
92+
Returns:
93+
Decoded UTF-8 string
94+
95+
Raises:
96+
ValueError: If string exceeds size limit or is invalid base64
97+
"""
5998
if len(s) > MAX_STRING_LENGTH:
6099
msg = f"String exceeds maximum length of {MAX_STRING_LENGTH}"
61100
raise ValueError(msg)
@@ -66,21 +105,66 @@ def safe_base64_decode(s: str) -> str:
66105
raise ValueError(msg)
67106

68107

108+
@beartype
69109
def safe_base64_encode(s: str) -> str:
110+
"""
111+
Safely encode a string to base64 with size limits.
112+
113+
Args:
114+
s: String to encode
115+
116+
Returns:
117+
Base64 encoded string
118+
119+
Raises:
120+
ValueError: If string exceeds size limit
121+
"""
70122
if len(s) > MAX_STRING_LENGTH:
71123
msg = f"String exceeds maximum length of {MAX_STRING_LENGTH}"
72124
raise ValueError(msg)
73125
return base64.b64encode(s.encode("utf-8")).decode("utf-8")
74126

75127

76-
def safe_random_choice(seq):
128+
@beartype
129+
def safe_random_choice(seq: list[Any] | tuple[Any, ...] | str) -> Any:
130+
"""
131+
Safely choose a random element from a sequence with size limits.
132+
133+
Args:
134+
seq: A sequence (list, tuple, or string) to choose from
135+
136+
Returns:
137+
A randomly selected element
138+
139+
Raises:
140+
ValueError: If sequence exceeds size limit
141+
TypeError: If input is not a valid sequence type
142+
"""
77143
if len(seq) > MAX_COLLECTION_SIZE:
78144
msg = f"Sequence exceeds maximum size of {MAX_COLLECTION_SIZE}"
79145
raise ValueError(msg)
80146
return random.choice(seq)
81147

82148

83-
def safe_random_sample(population, k):
149+
@beartype
150+
def safe_random_sample(population: list[T] | tuple[T, ...] | str, k: int) -> list[T]:
151+
"""
152+
Safely sample k elements from a population with size limits.
153+
154+
Args:
155+
population: A sequence to sample from
156+
k: Number of elements to sample
157+
158+
Returns:
159+
A list containing k randomly selected elements
160+
161+
Raises:
162+
ValueError: If population/sample size exceeds limits
163+
TypeError: If input is not a valid sequence type
164+
"""
165+
if not isinstance(population, list | tuple | str):
166+
msg = "Expected a sequence (list, tuple, or string)"
167+
raise TypeError(msg)
84168
if len(population) > MAX_COLLECTION_SIZE:
85169
msg = f"Population exceeds maximum size of {MAX_COLLECTION_SIZE}"
86170
raise ValueError(msg)
@@ -93,9 +177,19 @@ def safe_random_sample(population, k):
93177
return random.sample(population, k)
94178

95179

180+
@beartype
96181
def chunk_doc(string: str) -> list[str]:
97182
"""
98183
Chunk a string into sentences.
184+
185+
Args:
186+
string: The text to chunk into sentences
187+
188+
Returns:
189+
A list of sentence chunks
190+
191+
Raises:
192+
ValueError: If string exceeds size limit
99193
"""
100194
if len(string) > MAX_STRING_LENGTH:
101195
msg = f"String exceeds maximum length of {MAX_STRING_LENGTH}"

Diff for: agents-api/agents_api/clients/temporal.py

+4
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,10 @@ async def run_task_execution_workflow(
9090
):
9191
from ..workflows.task_execution import TaskExecutionWorkflow
9292

93+
if execution_input.execution is None:
94+
msg = "execution_input.execution cannot be None"
95+
raise ValueError(msg)
96+
9397
start: TransitionTarget = start or TransitionTarget(workflow="main", step=0)
9498

9599
client = client or (await get_client())

Diff for: agents-api/agents_api/workflows/task_execution/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,10 @@ async def run(
143143
start: TransitionTarget,
144144
previous_inputs: list,
145145
) -> Any:
146+
if not execution_input.task:
147+
msg = "execution_input.task cannot be None"
148+
raise ApplicationError(msg)
149+
146150
workflow.logger.info(
147151
f"TaskExecutionWorkflow for task {execution_input.task.id}"
148152
f" [LOC {start.workflow}.{start.step}]"

0 commit comments

Comments
 (0)