chore: remove cozo completely, and integrate postgres
Ahmad-mtos committed Dec 25, 2024
1 parent 7ab2bb1 commit 7798826
Show file tree
Hide file tree
Showing 30 changed files with 131 additions and 272 deletions.
2 changes: 1 addition & 1 deletion agents-api/Dockerfile
@@ -30,4 +30,4 @@ COPY . ./
ENV PYTHONUNBUFFERED=1
ENV GUNICORN_CMD_ARGS="--capture-output --enable-stdio-inheritance"

ENTRYPOINT ["uv", "run", "gunicorn", "agents_api.web:app", "-c", "gunicorn_conf.py"]
ENTRYPOINT ["uv", "run", "--offline", "--no-sync", "gunicorn", "agents_api.web:app", "-c", "gunicorn_conf.py"]
2 changes: 1 addition & 1 deletion agents-api/Dockerfile.worker
@@ -30,4 +30,4 @@ COPY . ./
ENV PYTHONUNBUFFERED=1
ENV GUNICORN_CMD_ARGS="--capture-output --enable-stdio-inheritance"

ENTRYPOINT ["uv", "run", "python", "-m", "agents_api.worker"]
ENTRYPOINT ["uv", "run", "--offline", "--no-sync", "python", "-m", "agents_api.worker"]
3 changes: 3 additions & 0 deletions agents-api/agents_api/__init__.py
@@ -9,3 +9,6 @@

with workflow.unsafe.imports_passed_through():
import msgpack as msgpack

import os

2 changes: 1 addition & 1 deletion agents-api/agents_api/activities/execute_integration.py
@@ -8,7 +8,7 @@
from ..common.exceptions.tools import IntegrationExecutionException
from ..common.protocol.tasks import ExecutionInput, StepContext
from ..env import testing
from ..models.tools import get_tool_args_from_metadata
from ..queries.tools import get_tool_args_from_metadata


@beartype
2 changes: 1 addition & 1 deletion agents-api/agents_api/activities/task_steps/__init__.py
@@ -2,7 +2,7 @@

from .base_evaluate import base_evaluate

# from .cozo_query_step import cozo_query_step
from .pg_query_step import pg_query_step
from .evaluate_step import evaluate_step
from .for_each_step import for_each_step
from .get_value_step import get_value_step
4 changes: 2 additions & 2 deletions agents-api/agents_api/activities/task_steps/pg_query_step.py
@@ -6,7 +6,7 @@

from ... import queries
from ...clients.pg import create_db_pool
from ...env import db_dsn, testing
from ...env import pg_dsn, testing


@alru_cache(maxsize=1)
@@ -18,7 +18,7 @@ async def get_db_pool(dsn: str):
async def pg_query_step(
query_name: str,
values: dict[str, Any],
dsn: str = db_dsn,
dsn: str = pg_dsn,
) -> Any:
pool = await get_db_pool(dsn=dsn)

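The new pg_query_step activity looks up a query function by name and runs it against a cached asyncpg pool. A minimal sketch of that pattern, assuming a dotted "module.function" naming convention and a connection_pool keyword argument (both are assumptions, not confirmed by this hunk):

# Sketch only. get_db_pool mirrors the cached-pool helper in this file; the dotted
# query-name convention and the connection_pool keyword are assumptions.
from typing import Any

import asyncpg
from async_lru import alru_cache


@alru_cache(maxsize=1)
async def get_db_pool(dsn: str) -> asyncpg.Pool:
    # One pool per process; repeated activity invocations reuse its connections.
    return await asyncpg.create_pool(dsn)


async def run_named_query(
    queries_pkg: Any, query_name: str, values: dict[str, Any], dsn: str
) -> Any:
    pool = await get_db_pool(dsn)
    # e.g. "executions.create_execution_transition"
    module_name, function_name = query_name.split(".", 1)
    query_fn = getattr(getattr(queries_pkg, module_name), function_name)
    return await query_fn(**values, connection_pool=pool)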
@@ -15,7 +15,7 @@
)
from ...exceptions import LastErrorInput, TooManyRequestsError
from ...queries.executions.create_execution_transition import (
create_execution_transition_async,
create_execution_transition,
)
from ..utils import RateLimiter

@@ -52,7 +52,7 @@ async def transition_step(

# Create transition
try:
transition = await create_execution_transition_async(
transition = await create_execution_transition(
developer_id=context.execution_input.developer_id,
execution_id=context.execution_input.execution.id,
task_id=context.execution_input.task.id,
30 changes: 16 additions & 14 deletions agents-api/agents_api/app.py
@@ -17,10 +17,10 @@
@asynccontextmanager
async def lifespan(app: FastAPI):
# INIT POSTGRES #
db_dsn = os.environ.get("DB_DSN")
pg_dsn = os.environ.get("PG_DSN")

if not getattr(app.state, "postgres_pool", None):
app.state.postgres_pool = await create_db_pool(db_dsn)
app.state.postgres_pool = await create_db_pool(pg_dsn)

# INIT S3 #
s3_access_key = os.environ.get("S3_ACCESS_KEY")
@@ -67,7 +67,8 @@ async def lifespan(app: FastAPI):
lifespan=lifespan,
#
# Global dependencies
dependencies=[Depends(valid_content_length)],
# FIXME: This is blocking access to scalar
# dependencies=[Depends(valid_content_length)],
)

# Enable metrics
@@ -92,19 +92,20 @@ async def scalar_html():


# content-length validation
# FIXME: This is blocking access to scalar
# NOTE: This relies on client reporting the correct content-length header
# TODO: We should use streaming for large payloads
@app.middleware("http")
async def validate_content_length(
request: Request,
call_next: Callable[[Request], Coroutine[Any, Any, Response]],
):
content_length = request.headers.get("content-length")
# @app.middleware("http")
# async def validate_content_length(
# request: Request,
# call_next: Callable[[Request], Coroutine[Any, Any, Response]],
# ):
# content_length = request.headers.get("content-length")

if not content_length:
return Response(status_code=411, content="Content-Length header is required")
# if not content_length:
# return Response(status_code=411, content="Content-Length header is required")

if int(content_length) > max_payload_size:
return Response(status_code=413, content="Payload too large")
# if int(content_length) > max_payload_size:
# return Response(status_code=413, content="Payload too large")

return await call_next(request)
# return await call_next(request)
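The global valid_content_length dependency and the middleware above are commented out because they block access to the Scalar docs page. One possible shape for restoring the check while skipping the docs route; the /docs path prefix and the max_payload_size value here are assumptions, not part of this commit:

# Hypothetical alternative, not the committed code: enforce Content-Length for
# everything except the docs route so Scalar stays reachable.
from collections.abc import Callable, Coroutine
from typing import Any

from fastapi import FastAPI, Request, Response

app = FastAPI()
max_payload_size = 50 * 1024 * 1024  # illustrative limit


@app.middleware("http")
async def validate_content_length(
    request: Request,
    call_next: Callable[[Request], Coroutine[Any, Any, Response]],
) -> Response:
    if request.url.path.startswith("/docs"):
        return await call_next(request)

    content_length = request.headers.get("content-length")
    if not content_length:
        return Response(status_code=411, content="Content-Length header is required")
    if int(content_length) > max_payload_size:
        return Response(status_code=413, content="Payload too large")

    return await call_next(request)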
4 changes: 2 additions & 2 deletions agents-api/agents_api/clients/pg.py
@@ -2,7 +2,7 @@

import asyncpg

from ..env import db_dsn
from ..env import pg_dsn


async def _init_conn(conn):
@@ -16,5 +16,5 @@ async def _init_conn(conn):

async def create_db_pool(dsn: str | None = None):
return await asyncpg.create_pool(
dsn if dsn is not None else db_dsn, init=_init_conn
dsn if dsn is not None else pg_dsn, init=_init_conn
)
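A usage sketch for the pool helper above; the import path follows the package layout in this commit, and the DSN value is illustrative (in the service it comes from PG_DSN via env.pg_dsn):

# Sketch only: create the pool once and reuse its connections.
import asyncio

from agents_api.clients.pg import create_db_pool


async def main() -> None:
    pool = await create_db_pool(
        "postgres://postgres:postgres@localhost:5432/postgres?sslmode=disable"
    )
    async with pool.acquire() as conn:
        # asyncpg uses $1-style positional parameters, matching the queries in this commit.
        row = await conn.fetchrow("SELECT $1::int AS answer", 42)
        print(row["answer"])
    await pool.close()


asyncio.run(main())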
2 changes: 1 addition & 1 deletion agents-api/agents_api/common/protocol/tasks.py
@@ -139,7 +139,7 @@ class PartialTransition(create_partial_model(CreateTransitionRequest)):

class ExecutionInput(BaseModel):
developer_id: UUID
execution: Execution
execution: Execution | None = None
task: TaskSpecDef
agent: Agent
agent_tools: list[Tool | CreateToolRequest]
21 changes: 5 additions & 16 deletions agents-api/agents_api/env.py
@@ -51,24 +51,15 @@
s3_secret_key: str | None = env.str("S3_SECRET_KEY", default=None)


# Cozo
# ----
cozo_host: str = env.str("COZO_HOST", default="http://127.0.0.1:9070")
cozo_auth: str = env.str("COZO_AUTH_TOKEN", default=None)
summarization_model_name: str = env.str(
"SUMMARIZATION_MODEL_NAME", default="gpt-4-turbo"
)
do_verify_developer: bool = env.bool("DO_VERIFY_DEVELOPER", default=True)
do_verify_developer_owns_resource: bool = env.bool(
"DO_VERIFY_DEVELOPER_OWNS_RESOURCE", default=True
)

# PostgreSQL
# ----
db_dsn: str = env.str(
"DB_DSN",
pg_dsn: str = env.str(
"PG_DSN",
default="postgres://postgres:[email protected]:5432/postgres?sslmode=disable",
)
summarization_model_name: str = env.str(
"SUMMARIZATION_MODEL_NAME", default="gpt-4-turbo"
)

query_timeout: float = env.float("QUERY_TIMEOUT", default=90.0)

@@ -156,8 +147,6 @@ def _parse_optional_int(val: str | None) -> int | None:
environment: Dict[str, Any] = dict(
debug=debug,
multi_tenant_mode=multi_tenant_mode,
cozo_host=cozo_host,
cozo_auth=cozo_auth,
sentry_dsn=sentry_dsn,
temporal_endpoint=temporal_endpoint,
temporal_task_queue=temporal_task_queue,
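The Cozo settings are removed and the connection string is now read from PG_DSN. A minimal sketch of that pattern; the environs import and the read_env call are assumptions, since only the env.str/env.float lines appear in this hunk:

# Sketch of the PG_DSN handling above, assuming environs provides the `env` object.
from environs import Env

env = Env()
env.read_env()  # picks up a local .env file if one exists

pg_dsn: str = env.str(
    "PG_DSN",
    default="postgres://postgres:[email protected]:5432/postgres?sslmode=disable",
)
query_timeout: float = env.float("QUERY_TIMEOUT", default=90.0)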
9 changes: 2 additions & 7 deletions agents-api/agents_api/queries/chat/prepare_chat_context.py
@@ -1,11 +1,9 @@
from typing import Any, TypeVar
from uuid import UUID

import sqlvalidator
from beartype import beartype

from ...common.protocol.sessions import ChatContext, make_session
from ...exceptions import InvalidSQLQuery
from ..utils import (
pg_query,
wrap_in_class,
@@ -15,8 +13,8 @@
T = TypeVar("T")


sql_query = sqlvalidator.parse(
"""SELECT * FROM
sql_query ="""
SELECT * FROM
(
SELECT jsonb_agg(u) AS users FROM (
SELECT
@@ -103,9 +101,6 @@
session_lookup.participant_type = 'agent'
) r
) AS toolsets"""
)
if not sql_query.is_valid():
raise InvalidSQLQuery("prepare_chat_context")


def _transform(d):
40 changes: 20 additions & 20 deletions agents-api/agents_api/queries/executions/prepare_execution_input.py
@@ -14,14 +14,17 @@

sql_query = """SELECT * FROM
(
SELECT to_jsonb(a) AS agents FROM (
SELECT to_jsonb(a) AS agent FROM (
SELECT * FROM agents
WHERE
developer_id = $1 AND
agent_id = $4
agent_id = (
SELECT agent_id FROM tasks
WHERE developer_id = $1 AND task_id = $2
)
LIMIT 1
) a
) AS agents,
) AS agent,
(
SELECT jsonb_agg(r) AS tools FROM (
SELECT * FROM tools
@@ -31,25 +34,25 @@
) r
) AS tools,
(
SELECT to_jsonb(t) AS tasks FROM (
SELECT to_jsonb(t) AS task FROM (
SELECT * FROM tasks
WHERE
developer_id = $1 AND
task_id = $2
LIMIT 1
) t
) AS tasks,
(
SELECT to_jsonb(e) AS executions FROM (
SELECT * FROM executions
WHERE
developer_id = $1 AND
task_id = $2 AND
execution_id = $3
LIMIT 1
) e
) AS executions;
) AS task;
"""
# (
# SELECT to_jsonb(e) AS execution FROM (
# SELECT * FROM latest_executions
# WHERE
# developer_id = $1 AND
# task_id = $2 AND
# execution_id = $3
# LIMIT 1
# ) e
# ) AS execution;


# @rewrap_exceptions(
@@ -70,7 +73,7 @@
transform=lambda d: {
**d,
"task": {
"tools": [*d["task"].pop("tools")],
"tools": d["tools"],
**d["task"],
},
"agent_tools": [
@@ -86,14 +89,11 @@ async def prepare_execution_input(
task_id: UUID,
execution_id: UUID,
) -> tuple[str, list]:
dummy_agent_id = UUID(int=0)

return (
sql_query,
[
str(developer_id),
str(task_id),
str(execution_id),
str(dummy_agent_id),
# str(execution_id),
],
)
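Since the query now returns top-level agent, tools, and task objects, the transform folds the tools list into the task dict. An illustrative example of that reshaping, with made-up field values:

# Illustrative only: shape of one result row and what the transform above produces.
row = {
    "agent": {"name": "support-agent", "model": "gpt-4-turbo"},
    "tools": [{"name": "web_search", "type": "integration"}],
    "task": {"name": "triage", "main": []},
}

transformed = {
    **row,
    "task": {
        "tools": row["tools"],
        **row["task"],
    },
}

assert transformed["task"]["tools"] == [{"name": "web_search", "type": "integration"}]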
6 changes: 3 additions & 3 deletions agents-api/agents_api/queries/tasks/create_or_update_task.py
@@ -192,13 +192,13 @@ async def create_or_update_task(
tool.type,
tool.name,
tool.description,
getattr(tool, tool.type), # spec
getattr(tool, tool.type) and getattr(tool, tool.type).model_dump(mode="json"), # spec
]
for tool in data.tools or []
]

# Generate workflows from task data using task_to_spec
workflows_spec = task_to_spec(data).model_dump(exclude_none=True, mode="json")
workflows_spec = task_to_spec(data).model_dump(mode="json")
workflow_params = []
for workflow in workflows_spec.get("workflows", []):
workflow_name = workflow.get("name")
@@ -211,7 +211,7 @@
workflow_name, # $3
step_idx, # $4
step["kind_"], # $5
step[step["kind_"]], # $6
step, # $6
]
)

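The tool spec is now serialized to plain JSON before the insert, and the `and` guard keeps a missing spec as None. A small sketch of that behaviour, using a made-up Pydantic tool model:

# Illustrative only: why `getattr(tool, tool.type) and ...model_dump(mode="json")`
# is None-safe. The Tool model here is made up for the example.
from pydantic import BaseModel


class FunctionSpec(BaseModel):
    name: str
    parameters: dict = {}


class Tool(BaseModel):
    type: str
    name: str
    description: str | None = None
    function: FunctionSpec | None = None


tool = Tool(type="function", name="add", function=FunctionSpec(name="add"))
spec = getattr(tool, tool.type) and getattr(tool, tool.type).model_dump(mode="json")
print(spec)  # {'name': 'add', 'parameters': {}}

empty = Tool(type="function", name="noop")
print(getattr(empty, empty.type) and getattr(empty, empty.type).model_dump(mode="json"))  # None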
6 changes: 3 additions & 3 deletions agents-api/agents_api/queries/tasks/create_task.py
@@ -167,13 +167,13 @@ async def create_task(
tool.type,
tool.name,
tool.description,
getattr(tool, tool.type), # spec
getattr(tool, tool.type) and getattr(tool, tool.type).model_dump(mode="json"), # spec
]
for tool in data.tools or []
]

# Generate workflows from task data using task_to_spec
workflows_spec = task_to_spec(data).model_dump(exclude_none=True, mode="json")
workflows_spec = task_to_spec(data).model_dump(mode="json")
workflow_params = []
for workflow in workflows_spec.get("workflows", []):
workflow_name = workflow.get("name")
@@ -187,7 +187,7 @@
workflow_name, # $4
step_idx, # $5
step["kind_"], # $6
step[step["kind_"]], # $7
step, # $7
]
)

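Workflow steps are now stored whole in step_definition rather than only the kind-specific payload, which is why get_task (next file) can return jsonb_build_array(w.step_definition) directly. An illustrative before/after, using a made-up evaluate step:

# Illustrative only: what gets bound as the step_definition parameter.
step = {"kind_": "evaluate", "evaluate": {"result": "_['x'] + 1"}}

old_value = step[step["kind_"]]  # before: {'result': "_['x'] + 1"}
new_value = step                 # after: the full step, including its kind_

print(old_value)
print(new_value)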
7 changes: 1 addition & 6 deletions agents-api/agents_api/queries/tasks/get_task.py
@@ -17,12 +17,7 @@
CASE WHEN w.name IS NOT NULL THEN
jsonb_build_object(
'name', w.name,
'steps', jsonb_build_array(
jsonb_build_object(
w.step_type, w.step_definition,
'step_idx', w.step_idx -- Not sure if this is needed
)
)
'steps', jsonb_build_array(w.step_definition)
)
END
) FILTER (WHERE w.name IS NOT NULL),
2 changes: 1 addition & 1 deletion agents-api/agents_api/queries/tasks/patch_task.py
@@ -198,7 +198,7 @@ async def patch_task(
else:
workflow_query = new_workflows_query
workflow_params = []
workflows_spec = task_to_spec(data).model_dump(exclude_none=True, mode="json")
workflows_spec = task_to_spec(data).model_dump(mode="json")
for workflow in workflows_spec.get("workflows", []):
workflow_name = workflow.get("name")
steps = workflow.get("steps", [])
2 changes: 1 addition & 1 deletion agents-api/agents_api/queries/tasks/update_task.py
@@ -137,7 +137,7 @@ async def update_task(
]

# Generate workflows from task data
workflows_spec = task_to_spec(data).model_dump(exclude_none=True, mode="json")
workflows_spec = task_to_spec(data).model_dump(mode="json")
workflow_params = []
for workflow in workflows_spec.get("workflows", []):
workflow_name = workflow.get("name")