Skip to content

Commit dfc9af3

Browse files
authored
Merge pull request #1013 from julep-ai/f/add-mmr-pg
feat(agents-api): added mmr to chat
2 parents 63894e4 + be69c60 commit dfc9af3

27 files changed

+648
-472
lines changed

Diff for: agents-api/agents_api/activities/container.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,19 @@
1+
from typing import Any
2+
3+
from aiobotocore.client import AioBaseClient
4+
from asyncpg.pool import Pool
5+
6+
17
class State:
2-
pass
8+
postgres_pool: Pool | None
9+
s3_client: AioBaseClient | None
10+
11+
def __init__(self):
12+
self.postgres_pool = None
13+
self.s3_client = None
14+
15+
def __setattr__(self, name: str, value: Any) -> None:
16+
super().__setattr__(name, value)
317

418

519
class Container:

Diff for: agents-api/agents_api/activities/execute_integration.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from ..common.exceptions.tools import IntegrationExecutionException
1010
from ..common.protocol.tasks import ExecutionInput, StepContext
1111
from ..env import testing
12-
from ..queries.tools import get_tool_args_from_metadata
12+
from ..queries import tools
1313
from .container import container
1414

1515

@@ -28,17 +28,22 @@ async def execute_integration(
2828

2929
developer_id = context.execution_input.developer_id
3030
agent_id = context.execution_input.agent.id
31+
32+
if context.execution_input.task is None:
33+
msg = "Task cannot be None in execution_input"
34+
raise ValueError(msg)
35+
3136
task_id = context.execution_input.task.id
3237

33-
merged_tool_args = await get_tool_args_from_metadata(
38+
merged_tool_args = await tools.get_tool_args_from_metadata(
3439
developer_id=developer_id,
3540
agent_id=agent_id,
3641
task_id=task_id,
3742
arg_type="args",
3843
connection_pool=container.state.postgres_pool,
3944
)
4045

41-
merged_tool_setup = await get_tool_args_from_metadata(
46+
merged_tool_setup = await tools.get_tool_args_from_metadata(
4247
developer_id=developer_id,
4348
agent_id=agent_id,
4449
task_id=task_id,

Diff for: agents-api/agents_api/activities/execute_system.py

+3-7
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
)
2424
from ..common.protocol.tasks import ExecutionInput, StepContext
2525
from ..env import testing
26-
from ..queries.developers import get_developer
26+
from ..queries import developers
2727
from .container import container
2828
from .utils import get_handler
2929

@@ -78,14 +78,10 @@ async def execute_system(
7878
# Handle special cases for doc operations
7979
if system.operation == "create" and system.subresource == "doc":
8080
arguments["x_developer_id"] = arguments.pop("developer_id")
81-
bg_runner = BackgroundTasks()
82-
res = await handler(
81+
return await handler(
8382
data=CreateDocRequest(**arguments.pop("data")),
84-
background_tasks=bg_runner,
8583
**arguments,
8684
)
87-
await bg_runner()
88-
return res
8985

9086
# Handle search operations
9187
if system.operation == "search" and system.subresource == "doc":
@@ -95,7 +91,7 @@ async def execute_system(
9591

9692
# Handle chat operations
9793
if system.operation == "chat" and system.resource == "session":
98-
developer = await get_developer(
94+
developer = await developers.get_developer(
9995
developer_id=arguments["developer_id"],
10096
connection_pool=container.state.postgres_pool,
10197
)

Diff for: agents-api/agents_api/activities/task_steps/base_evaluate.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -42,14 +42,7 @@ def _recursive_evaluate(expr, evaluator: SimpleEval):
4242
except Exception as e:
4343
if activity.in_activity():
4444
evaluate_error = EvaluateError(e, expr, evaluator.names)
45-
46-
variables_accessed = {
47-
name: value for name, value in evaluator.names.items() if name in expr
48-
}
49-
50-
activity.logger.error(
51-
f"Error in base_evaluate: {evaluate_error}\nVariables accessed: {variables_accessed}"
52-
)
45+
activity.logger.error(f"Error in base_evaluate: {evaluate_error}\n")
5346
raise evaluate_error from e
5447
elif isinstance(expr, list):
5548
return [_recursive_evaluate(e, evaluator) for e in expr]

Diff for: agents-api/agents_api/activities/task_steps/transition_step.py

+4
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ async def transition_step(
3939
msg = "Expected ExecutionInput type for context.execution_input"
4040
raise TypeError(msg)
4141

42+
if not context.execution_input.execution:
43+
msg = "Execution is required in execution_input"
44+
raise ValueError(msg)
45+
4246
# Create transition
4347
try:
4448
transition = await create_execution_transition(

Diff for: agents-api/agents_api/activities/task_steps/yield_step.py

+10
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,20 @@ async def yield_step(context: StepContext) -> StepOutcome:
1616
msg = "Expected ExecutionInput type for context.execution_input"
1717
raise TypeError(msg)
1818

19+
# Add validation for task
20+
if not context.execution_input.task:
21+
msg = "Task is required in execution_input"
22+
raise ValueError(msg)
23+
1924
all_workflows = context.execution_input.task.workflows
2025
workflow = context.current_step.workflow
2126
exprs = context.current_step.arguments
2227

28+
# Validate workflows exists
29+
if not all_workflows:
30+
msg = "No workflows found in task"
31+
raise ValueError(msg)
32+
2333
assert workflow in [wf.name for wf in all_workflows], (
2434
f"Workflow {workflow} not found in task"
2535
)

Diff for: agents-api/agents_api/activities/utils.py

+99-5
Original file line numberDiff line numberDiff line change
@@ -41,21 +41,60 @@ def safe_range(*args):
4141
return result
4242

4343

44-
def safe_json_loads(s: str):
44+
@beartype
45+
def safe_json_loads(s: str) -> Any:
46+
"""
47+
Safely load a JSON string with size limits.
48+
49+
Args:
50+
s: JSON string to parse
51+
52+
Returns:
53+
Parsed JSON data
54+
55+
Raises:
56+
ValueError: If string exceeds size limit
57+
"""
4558
if len(s) > MAX_STRING_LENGTH:
4659
msg = f"String exceeds maximum length of {MAX_STRING_LENGTH}"
4760
raise ValueError(msg)
4861
return json.loads(s)
4962

5063

51-
def safe_yaml_load(s: str):
64+
@beartype
65+
def safe_yaml_load(s: str) -> Any:
66+
"""
67+
Safely load a YAML string with size limits.
68+
69+
Args:
70+
s: YAML string to parse
71+
72+
Returns:
73+
Parsed YAML data
74+
75+
Raises:
76+
ValueError: If string exceeds size limit
77+
"""
5278
if len(s) > MAX_STRING_LENGTH:
5379
msg = f"String exceeds maximum length of {MAX_STRING_LENGTH}"
5480
raise ValueError(msg)
5581
return yaml.load(s)
5682

5783

84+
@beartype
5885
def safe_base64_decode(s: str) -> str:
86+
"""
87+
Safely decode a base64 string with size limits.
88+
89+
Args:
90+
s: Base64 string to decode
91+
92+
Returns:
93+
Decoded UTF-8 string
94+
95+
Raises:
96+
ValueError: If string exceeds size limit or is invalid base64
97+
"""
5998
if len(s) > MAX_STRING_LENGTH:
6099
msg = f"String exceeds maximum length of {MAX_STRING_LENGTH}"
61100
raise ValueError(msg)
@@ -66,21 +105,66 @@ def safe_base64_decode(s: str) -> str:
66105
raise ValueError(msg)
67106

68107

108+
@beartype
69109
def safe_base64_encode(s: str) -> str:
110+
"""
111+
Safely encode a string to base64 with size limits.
112+
113+
Args:
114+
s: String to encode
115+
116+
Returns:
117+
Base64 encoded string
118+
119+
Raises:
120+
ValueError: If string exceeds size limit
121+
"""
70122
if len(s) > MAX_STRING_LENGTH:
71123
msg = f"String exceeds maximum length of {MAX_STRING_LENGTH}"
72124
raise ValueError(msg)
73125
return base64.b64encode(s.encode("utf-8")).decode("utf-8")
74126

75127

76-
def safe_random_choice(seq):
128+
@beartype
129+
def safe_random_choice(seq: list[Any] | tuple[Any, ...] | str) -> Any:
130+
"""
131+
Safely choose a random element from a sequence with size limits.
132+
133+
Args:
134+
seq: A sequence (list, tuple, or string) to choose from
135+
136+
Returns:
137+
A randomly selected element
138+
139+
Raises:
140+
ValueError: If sequence exceeds size limit
141+
TypeError: If input is not a valid sequence type
142+
"""
77143
if len(seq) > MAX_COLLECTION_SIZE:
78144
msg = f"Sequence exceeds maximum size of {MAX_COLLECTION_SIZE}"
79145
raise ValueError(msg)
80146
return random.choice(seq)
81147

82148

83-
def safe_random_sample(population, k):
149+
@beartype
150+
def safe_random_sample(population: list[T] | tuple[T, ...] | str, k: int) -> list[T]:
151+
"""
152+
Safely sample k elements from a population with size limits.
153+
154+
Args:
155+
population: A sequence to sample from
156+
k: Number of elements to sample
157+
158+
Returns:
159+
A list containing k randomly selected elements
160+
161+
Raises:
162+
ValueError: If population/sample size exceeds limits
163+
TypeError: If input is not a valid sequence type
164+
"""
165+
if not isinstance(population, list | tuple | str):
166+
msg = "Expected a sequence (list, tuple, or string)"
167+
raise TypeError(msg)
84168
if len(population) > MAX_COLLECTION_SIZE:
85169
msg = f"Population exceeds maximum size of {MAX_COLLECTION_SIZE}"
86170
raise ValueError(msg)
@@ -93,9 +177,19 @@ def safe_random_sample(population, k):
93177
return random.sample(population, k)
94178

95179

180+
@beartype
96181
def chunk_doc(string: str) -> list[str]:
97182
"""
98183
Chunk a string into sentences.
184+
185+
Args:
186+
string: The text to chunk into sentences
187+
188+
Returns:
189+
A list of sentence chunks
190+
191+
Raises:
192+
ValueError: If string exceeds size limit
99193
"""
100194
if len(string) > MAX_STRING_LENGTH:
101195
msg = f"String exceeds maximum length of {MAX_STRING_LENGTH}"
@@ -397,8 +491,8 @@ def get_handler(system: SystemDef) -> Callable:
397491
from ..queries.agents.update_agent import update_agent as update_agent_query
398492
from ..queries.docs.delete_doc import delete_doc as delete_doc_query
399493
from ..queries.docs.list_docs import list_docs as list_docs_query
494+
from ..queries.entries.get_history import get_history as get_history_query
400495
from ..queries.sessions.create_session import create_session as create_session_query
401-
from ..queries.sessions.delete_session import delete_session as delete_session_query
402496
from ..queries.sessions.get_session import get_session as get_session_query
403497
from ..queries.sessions.list_sessions import list_sessions as list_sessions_query
404498
from ..queries.sessions.update_session import update_session as update_session_query

Diff for: agents-api/agents_api/app.py

+18-11
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import os
22
from contextlib import asynccontextmanager
3-
from typing import Any, Protocol
3+
from typing import Protocol
44

5+
from aiobotocore.client import AioBaseClient
56
from aiobotocore.session import get_session
7+
from asyncpg.pool import Pool
68
from fastapi import APIRouter, FastAPI
79
from prometheus_fastapi_instrumentator import Instrumentator
810
from scalar_fastapi import get_scalar_api_reference
@@ -11,22 +13,23 @@
1113
from .env import api_prefix, hostname, protocol, public_port
1214

1315

14-
class Assignable(Protocol):
15-
def __setattr__(self, name: str, value: Any) -> None: ...
16+
class State(Protocol):
17+
postgres_pool: Pool | None
18+
s3_client: AioBaseClient | None
1619

1720

1821
class ObjectWithState(Protocol):
19-
state: Assignable
22+
state: State
2023

2124

2225
# TODO: This currently doesn't use env.py, we should move to using them
2326
@asynccontextmanager
24-
async def lifespan(*containers: list[FastAPI | ObjectWithState]):
27+
async def lifespan(*containers: FastAPI | ObjectWithState):
2528
# INIT POSTGRES #
2629
pg_dsn = os.environ.get("PG_DSN")
2730

2831
for container in containers:
29-
if not getattr(container.state, "postgres_pool", None):
32+
if hasattr(container, "state") and not getattr(container.state, "postgres_pool", None):
3033
container.state.postgres_pool = await create_db_pool(pg_dsn)
3134

3235
# INIT S3 #
@@ -35,7 +38,7 @@ async def lifespan(*containers: list[FastAPI | ObjectWithState]):
3538
s3_endpoint = os.environ.get("S3_ENDPOINT")
3639

3740
for container in containers:
38-
if not getattr(container.state, "s3_client", None):
41+
if hasattr(container, "state") and not getattr(container.state, "s3_client", None):
3942
session = get_session()
4043
container.state.s3_client = await session.create_client(
4144
"s3",
@@ -49,14 +52,18 @@ async def lifespan(*containers: list[FastAPI | ObjectWithState]):
4952
finally:
5053
# CLOSE POSTGRES #
5154
for container in containers:
52-
if getattr(container.state, "postgres_pool", None):
53-
await container.state.postgres_pool.close()
55+
if hasattr(container, "state") and getattr(container.state, "postgres_pool", None):
56+
pool = getattr(container.state, "postgres_pool", None)
57+
if pool:
58+
await pool.close()
5459
container.state.postgres_pool = None
5560

5661
# CLOSE S3 #
5762
for container in containers:
58-
if getattr(container.state, "s3_client", None):
59-
await container.state.s3_client.close()
63+
if hasattr(container, "state") and getattr(container.state, "s3_client", None):
64+
s3_client = getattr(container.state, "s3_client", None)
65+
if s3_client:
66+
await s3_client.close()
6067
container.state.s3_client = None
6168

6269

Diff for: agents-api/agents_api/clients/litellm.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,8 @@ async def aembedding(
105105
embedding_list: list[dict[Literal["embedding"], list[float]]] = response.data
106106

107107
# Truncate the embedding to the specified dimensions
108-
embedding_list = [
108+
return [
109109
item["embedding"][:dimensions]
110110
for item in embedding_list
111111
if len(item["embedding"]) >= dimensions
112112
]
113-
114-
return embedding_list

Diff for: agents-api/agents_api/clients/temporal.py

+4
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,10 @@ async def run_task_execution_workflow(
9090
):
9191
from ..workflows.task_execution import TaskExecutionWorkflow
9292

93+
if execution_input.execution is None:
94+
msg = "execution_input.execution cannot be None"
95+
raise ValueError(msg)
96+
9397
start: TransitionTarget = start or TransitionTarget(workflow="main", step=0)
9498

9599
client = client or (await get_client())

0 commit comments

Comments
 (0)