Skip to content

Commit 3d38df9

Browse files
authored
Merge pull request #1068 from julep-ai/x/continue-as-new-bug
fix(agents-api): Fix blob-store interceptor implementation
2 parents f49ea57 + 21040bb commit 3d38df9

File tree

15 files changed: +373 -84 lines changed

agents-api/agents_api/activities/sync_items_remote.py

+3 -4
Original file line number | Diff line number | Diff line change
@@ -1,24 +1,23 @@
1-
import asyncio
21
from typing import Any
32

43
from beartype import beartype
54
from temporalio import activity
65

7-
from ..common.protocol.remote import RemoteObject
6+
from ..worker.codec import RemoteObject
87

98

109
@beartype
1110
async def save_inputs_remote_fn(inputs: list[Any]) -> list[Any | RemoteObject]:
1211
from ..common.interceptors import offload_if_large
1312

14-
return await asyncio.gather(*[offload_if_large(input) for input in inputs])
13+
return [offload_if_large(input) for input in inputs]
1514

1615

1716
@beartype
1817
async def load_inputs_remote_fn(inputs: list[Any | RemoteObject]) -> list[Any]:
1918
from ..common.interceptors import load_if_remote
2019

21-
return await asyncio.gather(*[load_if_remote(input) for input in inputs])
20+
return [load_if_remote(input) for input in inputs]
2221

2322

2423
save_inputs_remote = activity.defn(name="save_inputs_remote")(save_inputs_remote_fn)

agents-api/agents_api/activities/task_steps/evaluate_step.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -24,4 +24,4 @@ async def evaluate_step(
2424

2525
except BaseException as e:
2626
activity.logger.error(f"Error in evaluate_step: {e}")
27-
return StepOutcome(error=str(e) or repr(e))
27+
return StepOutcome(error=str(e) or repr(e), output=None)

agents-api/agents_api/app.py

+12 -9
Original file line number | Diff line number | Diff line change
@@ -2,15 +2,18 @@
22
from contextlib import asynccontextmanager
33
from typing import Protocol
44

5-
from aiobotocore.client import AioBaseClient
6-
from aiobotocore.session import get_session
7-
from asyncpg.pool import Pool
8-
from fastapi import APIRouter, FastAPI
9-
from prometheus_fastapi_instrumentator import Instrumentator
10-
from scalar_fastapi import get_scalar_api_reference
5+
from temporalio import workflow
116

12-
from .clients.pg import create_db_pool
13-
from .env import api_prefix, hostname, pool_max_size, protocol, public_port
7+
with workflow.unsafe.imports_passed_through():
8+
from aiobotocore.client import AioBaseClient
9+
from aiobotocore.session import get_session
10+
from asyncpg.pool import Pool
11+
from fastapi import APIRouter, FastAPI
12+
from prometheus_fastapi_instrumentator import Instrumentator
13+
from scalar_fastapi import get_scalar_api_reference
14+
15+
from .clients.pg import create_db_pool
16+
from .env import api_prefix, hostname, pool_max_size, protocol, public_port
1417

1518

1619
class State(Protocol):
@@ -28,7 +31,7 @@ async def lifespan(container: FastAPI | ObjectWithState):
2831
# INIT POSTGRES #
2932
pg_dsn = os.environ.get("PG_DSN")
3033

31-
pool = await create_db_pool(pg_dsn, max_size=pool_max_size)
34+
pool = await create_db_pool(pg_dsn, max_size=pool_max_size, min_size=min(pool_max_size, 10))
3235

3336
if hasattr(container, "state") and not getattr(container.state, "postgres_pool", None):
3437
container.state.postgres_pool = pool

agents-api/agents_api/clients/async_s3.py

+5 -9
Original file line number | Diff line number | Diff line change
@@ -3,25 +3,22 @@
33

44
with workflow.unsafe.imports_passed_through():
55
import botocore
6+
from aiobotocore.client import AioBaseClient
67
from async_lru import alru_cache
78
from xxhash import xxh3_64_hexdigest as xxhash_key
89

9-
from ..env import (
10-
blob_store_bucket,
11-
blob_store_cutoff_kb,
12-
)
10+
from ..env import blob_store_bucket
1311

1412

1513
@alru_cache(maxsize=1)
16-
async def setup():
14+
async def setup() -> AioBaseClient:
1715
from ..app import app
1816

19-
if not app.state.s3_client:
17+
client: AioBaseClient | None = getattr(app.state, "s3_client", None)
18+
if client is None:
2019
msg = "S3 client not initialized"
2120
raise RuntimeError(msg)
2221

23-
client = app.state.s3_client
24-
2522
try:
2623
await client.head_bucket(Bucket=blob_store_bucket)
2724
except botocore.exceptions.ClientError as e:
@@ -68,7 +65,6 @@ async def add_object(key: str, body: bytes, replace: bool = False) -> None:
6865
await client.put_object(Bucket=blob_store_bucket, Key=key, Body=body)
6966

7067

71-
@alru_cache(maxsize=256 * 1024 // max(1, blob_store_cutoff_kb)) # 256mb in cache
7268
@beartype
7369
async def get_object(key: str) -> bytes:
7470
client = await setup()
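[File path not captured in this extract. The hunk below is a new file added by this commit; its contents are a synchronous botocore S3 client.]

+95 -0
Original file line number | Diff line number | Diff line change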
@@ -0,0 +1,95 @@
1+
import os
2+
from functools import lru_cache
3+
4+
from beartype import beartype
5+
from temporalio import workflow
6+
7+
with workflow.unsafe.imports_passed_through():
8+
import botocore
9+
from xxhash import xxh3_64_hexdigest as xxhash_key
10+
11+
from ..env import blob_store_bucket
12+
13+
14+
@lru_cache(maxsize=1)
15+
def setup():
16+
# INIT S3 #
17+
s3_access_key = os.environ.get("S3_ACCESS_KEY")
18+
s3_secret_key = os.environ.get("S3_SECRET_KEY")
19+
s3_endpoint = os.environ.get("S3_ENDPOINT")
20+
21+
session = botocore.session.Session()
22+
client = session.create_client(
23+
"s3",
24+
endpoint_url=s3_endpoint,
25+
aws_access_key_id=s3_access_key,
26+
aws_secret_access_key=s3_secret_key,
27+
config=botocore.config.Config(signature_version="s3v4", retries={"max_attempts": 3}),
28+
)
29+
30+
try:
31+
client.head_bucket(Bucket=blob_store_bucket)
32+
except botocore.exceptions.ClientError as e:
33+
if e.response["Error"]["Code"] == "404":
34+
client.create_bucket(Bucket=blob_store_bucket)
35+
else:
36+
raise e
37+
38+
return client
39+
40+
41+
@lru_cache(maxsize=1024)
42+
def list_buckets() -> list[str]:
43+
client = setup()
44+
45+
data = client.list_buckets()
46+
return [bucket["Name"] for bucket in data["Buckets"]]
47+
48+
49+
@lru_cache(maxsize=10_000)
50+
def exists(key: str) -> bool:
51+
client = setup()
52+
53+
try:
54+
client.head_object(Bucket=blob_store_bucket, Key=key)
55+
return True
56+
except botocore.exceptions.ClientError as e:
57+
if e.response["Error"]["Code"] == "404":
58+
return False
59+
raise e
60+
61+
62+
@beartype
63+
def add_object(key: str, body: bytes, replace: bool = False) -> None:
64+
client = setup()
65+
66+
if replace:
67+
client.put_object(Bucket=blob_store_bucket, Key=key, Body=body)
68+
return
69+
70+
if exists(key):
71+
return
72+
73+
client.put_object(Bucket=blob_store_bucket, Key=key, Body=body)
74+
75+
76+
@beartype
77+
def get_object(key: str) -> bytes:
78+
client = setup()
79+
80+
response = client.get_object(Bucket=blob_store_bucket, Key=key)
81+
return response["Body"].read()
82+
83+
84+
@beartype
85+
def delete_object(key: str) -> None:
86+
client = setup()
87+
client.delete_object(Bucket=blob_store_bucket, Key=key)
88+
89+
90+
@beartype
91+
def add_object_with_hash(body: bytes, replace: bool = False) -> str:
92+
key = xxhash_key(body)
93+
add_object(key, body, replace=replace)
94+
95+
return key

agents-api/agents_api/clients/temporal.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -114,7 +114,7 @@ async def run_task_execution_workflow(
114114
execution_id_key = SearchAttributeKey.for_keyword("CustomStringField")
115115

116116
old_args = execution_input.arguments
117-
execution_input.arguments = await offload_if_large(old_args)
117+
execution_input.arguments = offload_if_large(old_args)
118118

119119
current_input: dict = current_input or execution_input.arguments
120120

0 commit comments

Comments
 (0)