diff --git a/agents-api/agents_api/queries/executions/__init__.py b/agents-api/agents_api/queries/executions/__init__.py index 32a72b75c..1a298a551 100644 --- a/agents-api/agents_api/queries/executions/__init__.py +++ b/agents-api/agents_api/queries/executions/__init__.py @@ -1,5 +1,15 @@ # ruff: noqa: F401, F403, F405 +""" +The `execution` module provides SQL query functions for managing executions +in the TimescaleDB database. This includes operations for: + +- Creating new executions +- Deleting executions +- Retrieving execution history +- Listing executions with filtering and pagination +""" + from .count_executions import count_executions from .create_execution import create_execution from .create_execution_transition import create_execution_transition @@ -9,3 +19,16 @@ from .list_executions import list_executions from .lookup_temporal_data import lookup_temporal_data from .prepare_execution_input import prepare_execution_input + + +__all__ = [ + "count_executions", + "create_execution", + "create_execution_transition", + "get_execution", + "get_execution_transition", + "list_execution_transitions", + "list_executions", + "lookup_temporal_data", + "prepare_execution_input", +] diff --git a/agents-api/agents_api/queries/executions/count_executions.py b/agents-api/agents_api/queries/executions/count_executions.py index 7073be7b8..cd9fd8a9d 100644 --- a/agents-api/agents_api/queries/executions/count_executions.py +++ b/agents-api/agents_api/queries/executions/count_executions.py @@ -1,30 +1,40 @@ -from typing import Any, Literal, TypeVar +from typing import Any, Literal from uuid import UUID from beartype import beartype - +from fastapi import HTTPException +from sqlglot import parse_one +import asyncpg from ..utils import ( pg_query, wrap_in_class, + rewrap_exceptions, + partialclass, ) -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - -sql_query = """SELECT COUNT(*) FROM latest_executions +# Query to count executions for a given task +execution_count_query = parse_one(""" +SELECT COUNT(*) FROM latest_executions WHERE developer_id = $1 AND task_id = $2; -""" +""").sql(pretty=True) -# @rewrap_exceptions( -# { -# QueryException: partialclass(HTTPException, status_code=400), -# ValidationError: partialclass(HTTPException, status_code=400), -# TypeError: partialclass(HTTPException, status_code=400), -# } -# ) +@rewrap_exceptions( +{ + asyncpg.NoDataFoundError: partialclass( + HTTPException, + status_code=404, + detail="No executions found for the specified task" + ), + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer or task does not exist" + ), +} +) @wrap_in_class(dict, one=True) @pg_query @beartype @@ -33,4 +43,18 @@ async def count_executions( developer_id: UUID, task_id: UUID, ) -> tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]: - return (sql_query, [developer_id, task_id], "fetchrow") + """ + Count the number of executions for a given task. + + Parameters: + developer_id (UUID): The ID of the developer. + task_id (UUID): The ID of the task. + + Returns: + tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]: SQL query and parameters for counting executions. + """ + return ( + execution_count_query, + [developer_id, task_id], + "fetchrow", + ) diff --git a/agents-api/agents_api/queries/executions/create_execution.py b/agents-api/agents_api/queries/executions/create_execution.py index 6c1737f2b..3f77be30e 100644 --- a/agents-api/agents_api/queries/executions/create_execution.py +++ b/agents-api/agents_api/queries/executions/create_execution.py @@ -8,16 +8,19 @@ from ...common.utils.datetime import utcnow from ...common.utils.types import dict_like from ...metrics.counters import increase_counter +from sqlglot import parse_one +import asyncpg +from fastapi import HTTPException from ..utils import ( pg_query, wrap_in_class, + rewrap_exceptions, + partialclass, ) from .constants import OUTPUT_UNNEST_KEY -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") -sql_query = """ +create_execution_query = parse_one(""" INSERT INTO executions ( developer_id, @@ -37,16 +40,23 @@ 1 ) RETURNING *; -""" +""").sql(pretty=True) -# @rewrap_exceptions( -# { -# QueryException: partialclass(HTTPException, status_code=400), -# ValidationError: partialclass(HTTPException, status_code=400), -# TypeError: partialclass(HTTPException, status_code=400), -# } -# ) +@rewrap_exceptions( +{ + asyncpg.NoDataFoundError: partialclass( + HTTPException, + status_code=404, + detail="No executions found for the specified task" + ), + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer or task does not exist" + ), +} +) @wrap_in_class( Execution, one=True, @@ -67,6 +77,18 @@ async def create_execution( execution_id: UUID | None = None, data: Annotated[CreateExecutionRequest | dict, dict_like(CreateExecutionRequest)], ) -> tuple[str, list]: + """ + Create a new execution. + + Parameters: + developer_id (UUID): The ID of the developer. + task_id (UUID): The ID of the task. + execution_id (UUID | None): The ID of the execution. + data (CreateExecutionRequest | dict): The data for the execution. + + Returns: + tuple[str, list]: SQL query and parameters for creating the execution. + """ execution_id = execution_id or uuid7() developer_id = str(developer_id) @@ -86,7 +108,7 @@ async def create_execution( execution_data["output"] = {OUTPUT_UNNEST_KEY: execution_data["output"]} return ( - sql_query, + create_execution_query, [ developer_id, task_id, diff --git a/agents-api/agents_api/queries/executions/create_execution_transition.py b/agents-api/agents_api/queries/executions/create_execution_transition.py index 8ff1be47d..ac12a84af 100644 --- a/agents-api/agents_api/queries/executions/create_execution_transition.py +++ b/agents-api/agents_api/queries/executions/create_execution_transition.py @@ -8,14 +8,20 @@ CreateTransitionRequest, Transition, ) +import asyncpg +from fastapi import HTTPException +from sqlglot import parse_one from ...common.utils.datetime import utcnow from ...metrics.counters import increase_counter from ..utils import ( pg_query, wrap_in_class, + rewrap_exceptions, + partialclass, ) -sql_query = """ +# Query to create a transition +create_execution_transition_query = parse_one(""" INSERT INTO transitions ( execution_id, @@ -43,7 +49,7 @@ $10 ) RETURNING *; -""" +""").sql(pretty=True) def validate_transition_targets(data: CreateTransitionRequest) -> None: @@ -80,13 +86,20 @@ def validate_transition_targets(data: CreateTransitionRequest) -> None: raise ValueError(f"Invalid transition type: {data.type}") -# rewrap_exceptions( -# { -# QueryException: partialclass(HTTPException, status_code=400), -# ValidationError: partialclass(HTTPException, status_code=400), -# TypeError: partialclass(HTTPException, status_code=400), -# } -# ) +@rewrap_exceptions( +{ + asyncpg.NoDataFoundError: partialclass( + HTTPException, + status_code=404, + detail="No executions found for the specified task" + ), + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer or task does not exist" + ), +} +) @wrap_in_class( Transition, transform=lambda d: { @@ -111,6 +124,19 @@ async def create_execution_transition( transition_id: UUID | None = None, task_token: str | None = None, ) -> tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]: + """ + Create a new execution transition. + + Parameters: + developer_id (UUID): The ID of the developer. + execution_id (UUID): The ID of the execution. + data (CreateTransitionRequest): The data for the transition. + transition_id (UUID | None): The ID of the transition. + task_token (str | None): The task token. + + Returns: + tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]: SQL query and parameters for creating the transition. + """ transition_id = transition_id or uuid7() data.metadata = data.metadata or {} data.execution_id = execution_id @@ -140,7 +166,7 @@ async def create_execution_transition( ) return ( - sql_query, + create_execution_transition_query, [ execution_id, transition_id, diff --git a/agents-api/agents_api/queries/executions/create_temporal_lookup.py b/agents-api/agents_api/queries/executions/create_temporal_lookup.py index 7303304a9..f352cb151 100644 --- a/agents-api/agents_api/queries/executions/create_temporal_lookup.py +++ b/agents-api/agents_api/queries/executions/create_temporal_lookup.py @@ -1,17 +1,20 @@ -from typing import TypeVar -from uuid import UUID from beartype import beartype from temporalio.client import WorkflowHandle +from sqlglot import parse_one +import asyncpg +from fastapi import HTTPException +from uuid import UUID from ...metrics.counters import increase_counter from ..utils import ( pg_query, + rewrap_exceptions, + partialclass, ) -T = TypeVar("T") - -sql_query = """ +# Query to create a temporal lookup +create_temporal_lookup_query = parse_one(""" INSERT INTO temporal_executions_lookup ( execution_id, @@ -29,17 +32,23 @@ $5 ) RETURNING *; -""" +""").sql(pretty=True) -# @rewrap_exceptions( -# { -# AssertionError: partialclass(HTTPException, status_code=404), -# QueryException: partialclass(HTTPException, status_code=400), -# ValidationError: partialclass(HTTPException, status_code=400), -# TypeError: partialclass(HTTPException, status_code=400), -# } -# ) +@rewrap_exceptions( +{ + asyncpg.NoDataFoundError: partialclass( + HTTPException, + status_code=404, + detail="No executions found for the specified task" + ), + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer or task does not exist" + ), +} +) @pg_query @increase_counter("create_temporal_lookup") @beartype @@ -49,11 +58,22 @@ async def create_temporal_lookup( execution_id: UUID, workflow_handle: WorkflowHandle, ) -> tuple[str, list]: + """ + Create a temporal lookup for a given execution. + + Parameters: + developer_id (UUID): The ID of the developer. + execution_id (UUID): The ID of the execution. + workflow_handle (WorkflowHandle): The workflow handle. + + Returns: + tuple[str, list]: SQL query and parameters for creating the temporal lookup. + """ developer_id = str(developer_id) execution_id = str(execution_id) return ( - sql_query, + create_temporal_lookup_query, [ execution_id, workflow_handle.id, diff --git a/agents-api/agents_api/queries/executions/get_execution.py b/agents-api/agents_api/queries/executions/get_execution.py index 993052157..52c20bdb1 100644 --- a/agents-api/agents_api/queries/executions/get_execution.py +++ b/agents-api/agents_api/queries/executions/get_execution.py @@ -1,9 +1,10 @@ -from typing import Any, Literal, TypeVar +from typing import Literal from uuid import UUID -from asyncpg.exceptions import NoDataFoundError -from beartype import beartype +import asyncpg from fastapi import HTTPException +from sqlglot import parse_one +from beartype import beartype from ...autogen.openapi_model import Execution from ..utils import ( @@ -14,20 +15,22 @@ ) from .constants import OUTPUT_UNNEST_KEY -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - -sql_query = """ +# Query to get an execution +get_execution_query = parse_one(""" SELECT * FROM latest_executions WHERE execution_id = $1 LIMIT 1; -""" +""").sql(pretty=True) @rewrap_exceptions( { - NoDataFoundError: partialclass(HTTPException, status_code=404), + asyncpg.NoDataFoundError: partialclass( + HTTPException, + status_code=404, + detail="No executions found for the specified task" + ), } ) @wrap_in_class( @@ -47,4 +50,17 @@ async def get_execution( *, execution_id: UUID, ) -> tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]: - return (sql_query, [execution_id], "fetchrow") + """ + Get an execution by its ID. + + Parameters: + execution_id (UUID): The ID of the execution. + + Returns: + tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]: SQL query and parameters for getting the execution. + """ + return ( + get_execution_query, + [execution_id], + "fetchrow", + ) diff --git a/agents-api/agents_api/queries/executions/get_execution_transition.py b/agents-api/agents_api/queries/executions/get_execution_transition.py index 8998c0c53..2b4c78684 100644 --- a/agents-api/agents_api/queries/executions/get_execution_transition.py +++ b/agents-api/agents_api/queries/executions/get_execution_transition.py @@ -1,10 +1,10 @@ from typing import Any, Literal, TypeVar from uuid import UUID -from asyncpg.exceptions import NoDataFoundError +import asyncpg from beartype import beartype from fastapi import HTTPException - +from sqlglot import parse_one from ...autogen.openapi_model import Transition from ..utils import ( partialclass, @@ -13,16 +13,14 @@ wrap_in_class, ) -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - -sql_query = """ +# Query to get an execution transition +get_execution_transition_query = parse_one(""" SELECT * FROM transitions WHERE transition_id = $1 OR task_token = $2 LIMIT 1; -""" +""").sql(pretty=True) def _transform(d): @@ -42,9 +40,18 @@ def _transform(d): @rewrap_exceptions( - { - NoDataFoundError: partialclass(HTTPException, status_code=404), - } +{ + asyncpg.NoDataFoundError: partialclass( + HTTPException, + status_code=404, + detail="No executions found for the specified task" + ), + asyncpg.ForeignKeyViolationError: partialclass( + HTTPException, + status_code=404, + detail="The specified developer or task does not exist" + ), +} ) @wrap_in_class(Transition, one=True, transform=_transform) @pg_query @@ -55,13 +62,24 @@ async def get_execution_transition( transition_id: UUID | None = None, task_token: str | None = None, ) -> tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]: + """ + Get an execution transition by its ID or task token. + + Parameters: + developer_id (UUID): The ID of the developer. + transition_id (UUID | None): The ID of the transition. + task_token (str | None): The task token. + + Returns: + tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]: SQL query and parameters for getting the execution transition. + """ # At least one of `transition_id` or `task_token` must be provided assert ( transition_id or task_token ), "At least one of `transition_id` or `task_token` must be provided." return ( - sql_query, + get_execution_transition_query, [ transition_id, task_token, diff --git a/agents-api/agents_api/queries/executions/get_paused_execution_token.py b/agents-api/agents_api/queries/executions/get_paused_execution_token.py index c6f9c8211..f9c981686 100644 --- a/agents-api/agents_api/queries/executions/get_paused_execution_token.py +++ b/agents-api/agents_api/queries/executions/get_paused_execution_token.py @@ -1,9 +1,10 @@ from typing import Any, Literal, TypeVar from uuid import UUID -from asyncpg.exceptions import NoDataFoundError +import asyncpg from beartype import beartype from fastapi import HTTPException +from sqlglot import parse_one from ..utils import ( partialclass, @@ -12,22 +13,24 @@ wrap_in_class, ) -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - -sql_query = """ +# Query to get a paused execution token +get_paused_execution_token_query = parse_one(""" SELECT * FROM transitions WHERE execution_id = $1 - AND type = 'wait' -ORDER BY created_at DESC -LIMIT 1; -""" + AND type = 'wait' + ORDER BY created_at DESC + LIMIT 1; +""").sql(pretty=True) @rewrap_exceptions( { - NoDataFoundError: partialclass(HTTPException, status_code=404), + asyncpg.NoDataFoundError: partialclass( + HTTPException, + status_code=404, + detail="No paused executions found for the specified task" + ), } ) @wrap_in_class(dict, one=True) @@ -38,6 +41,16 @@ async def get_paused_execution_token( developer_id: UUID, execution_id: UUID, ) -> tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]: + """ + Get a paused execution token for a given execution. + + Parameters: + developer_id (UUID): The ID of the developer. + execution_id (UUID): The ID of the execution. + + Returns: + tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]: SQL query and parameters for getting a paused execution token. + """ execution_id = str(execution_id) # TODO: what to do with this query? @@ -55,7 +68,8 @@ async def get_paused_execution_token( # """ return ( - sql_query, + get_paused_execution_token_query, [execution_id], "fetchrow", ) + diff --git a/agents-api/agents_api/queries/executions/get_temporal_workflow_data.py b/agents-api/agents_api/queries/executions/get_temporal_workflow_data.py index 41eb3e933..123516c94 100644 --- a/agents-api/agents_api/queries/executions/get_temporal_workflow_data.py +++ b/agents-api/agents_api/queries/executions/get_temporal_workflow_data.py @@ -1,9 +1,10 @@ -from typing import Any, Literal, TypeVar +from typing import Literal from uuid import UUID -from asyncpg.exceptions import NoDataFoundError +import asyncpg from beartype import beartype from fastapi import HTTPException +from sqlglot import parse_one from ..utils import ( partialclass, @@ -12,20 +13,22 @@ wrap_in_class, ) -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - -sql_query = """ +# Query to get temporal workflow data +get_temporal_workflow_data_query = parse_one(""" SELECT id, run_id, result_run_id, first_execution_run_id FROM temporal_executions_lookup WHERE execution_id = $1 LIMIT 1; -""" +""").sql(pretty=True) @rewrap_exceptions( { - NoDataFoundError: partialclass(HTTPException, status_code=404), + asyncpg.NoDataFoundError: partialclass( + HTTPException, + status_code=404, + detail="No temporal workflow data found for the specified execution" + ), } ) @wrap_in_class(dict, one=True) @@ -35,11 +38,20 @@ async def get_temporal_workflow_data( *, execution_id: UUID, ) -> tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]: + """ + Get temporal workflow data for a given execution. + + Parameters: + execution_id (UUID): The ID of the execution. + + Returns: + tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]: SQL query and parameters for getting temporal workflow data. + """ # Executions are allowed direct GET access if they have execution_id execution_id = str(execution_id) return ( - sql_query, + get_temporal_workflow_data_query, [ execution_id, ], diff --git a/agents-api/agents_api/queries/executions/list_execution_transitions.py b/agents-api/agents_api/queries/executions/list_execution_transitions.py index 5e0836aa6..07260a5d1 100644 --- a/agents-api/agents_api/queries/executions/list_execution_transitions.py +++ b/agents-api/agents_api/queries/executions/list_execution_transitions.py @@ -1,20 +1,16 @@ -from typing import Any, Literal, TypeVar +from typing import Literal from uuid import UUID -from asyncpg.exceptions import ( - InvalidRowCountInLimitClauseError, - InvalidRowCountInResultOffsetClauseError, -) +import asyncpg from beartype import beartype from fastapi import HTTPException +from sqlglot import parse_one from ...autogen.openapi_model import Transition from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - -sql_query = """ +# Query to list execution transitions +list_execution_transitions_query = parse_one(""" SELECT * FROM transitions WHERE execution_id = $1 @@ -24,7 +20,7 @@ CASE WHEN $4 = 'updated_at' AND $5 = 'asc' THEN updated_at END ASC NULLS LAST, CASE WHEN $4 = 'updated_at' AND $5 = 'desc' THEN updated_at END DESC NULLS LAST LIMIT $2 OFFSET $3; -""" +""").sql(pretty=True) def _transform(d): @@ -45,9 +41,15 @@ def _transform(d): @rewrap_exceptions( { - InvalidRowCountInLimitClauseError: partialclass(HTTPException, status_code=400), - InvalidRowCountInResultOffsetClauseError: partialclass( - HTTPException, status_code=400 + asyncpg.InvalidRowCountInLimitClauseError: partialclass( + HTTPException, + status_code=400, + detail="Invalid limit clause" + ), + asyncpg.InvalidRowCountInResultOffsetClauseError: partialclass( + HTTPException, + status_code=400, + detail="Invalid offset clause" ), } ) @@ -65,8 +67,21 @@ async def list_execution_transitions( sort_by: Literal["created_at", "updated_at"] = "created_at", direction: Literal["asc", "desc"] = "desc", ) -> tuple[str, list]: + """ + List execution transitions for a given execution. + + Parameters: + execution_id (UUID): The ID of the execution. + limit (int): The number of transitions to return. + offset (int): The number of transitions to skip. + sort_by (Literal["created_at", "updated_at"]): The field to sort by. + direction (Literal["asc", "desc"]): The direction to sort by. + + Returns: + tuple[str, list]: SQL query and parameters for listing execution transitions. + """ return ( - sql_query, + list_execution_transitions_query, [ str(execution_id), limit, diff --git a/agents-api/agents_api/queries/executions/list_executions.py b/agents-api/agents_api/queries/executions/list_executions.py index 2bb467fb8..2ffc0c003 100644 --- a/agents-api/agents_api/queries/executions/list_executions.py +++ b/agents-api/agents_api/queries/executions/list_executions.py @@ -1,13 +1,10 @@ from typing import Any, Literal, TypeVar from uuid import UUID -from asyncpg.exceptions import ( - InvalidRowCountInLimitClauseError, - InvalidRowCountInResultOffsetClauseError, -) +import asyncpg from beartype import beartype from fastapi import HTTPException - +from sqlglot import parse_one from ...autogen.openapi_model import Execution from ..utils import ( partialclass, @@ -17,10 +14,8 @@ ) from .constants import OUTPUT_UNNEST_KEY -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - -sql_query = """ +# Query to list executions +list_executions_query = parse_one(""" SELECT * FROM latest_executions WHERE developer_id = $1 AND @@ -31,14 +26,20 @@ CASE WHEN $3 = 'updated_at' AND $4 = 'asc' THEN updated_at END ASC NULLS LAST, CASE WHEN $3 = 'updated_at' AND $4 = 'desc' THEN updated_at END DESC NULLS LAST LIMIT $5 OFFSET $6; -""" +""").sql(pretty=True) @rewrap_exceptions( { - InvalidRowCountInLimitClauseError: partialclass(HTTPException, status_code=400), - InvalidRowCountInResultOffsetClauseError: partialclass( - HTTPException, status_code=400 + asyncpg.InvalidRowCountInLimitClauseError: partialclass( + HTTPException, + status_code=400, + detail="Invalid limit clause" + ), + asyncpg.InvalidRowCountInResultOffsetClauseError: partialclass( + HTTPException, + status_code=400, + detail="Invalid offset clause" ), } ) @@ -63,8 +64,22 @@ async def list_executions( sort_by: Literal["created_at", "updated_at"] = "created_at", direction: Literal["asc", "desc"] = "desc", ) -> tuple[str, list]: + """ + List executions for a given task. + + Parameters: + developer_id (UUID): The ID of the developer. + task_id (UUID): The ID of the task. + limit (int): The number of executions to return. + offset (int): The number of executions to skip. + sort_by (Literal["created_at", "updated_at"]): The field to sort by. + direction (Literal["asc", "desc"]): The direction to sort by. + + Returns: + tuple[str, list]: SQL query and parameters for listing executions. + """ return ( - sql_query, + list_executions_query, [ developer_id, task_id, diff --git a/agents-api/agents_api/queries/executions/lookup_temporal_data.py b/agents-api/agents_api/queries/executions/lookup_temporal_data.py index 59c3aef32..55d0bbd90 100644 --- a/agents-api/agents_api/queries/executions/lookup_temporal_data.py +++ b/agents-api/agents_api/queries/executions/lookup_temporal_data.py @@ -1,26 +1,28 @@ from typing import Any, Literal, TypeVar from uuid import UUID -from asyncpg.exceptions import NoDataFoundError +import asyncpg from beartype import beartype from fastapi import HTTPException - +from sqlglot import parse_one from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - -sql_query = """ +# Query to lookup temporal data +lookup_temporal_data_query = parse_one(""" SELECT * FROM temporal_executions_lookup WHERE execution_id = $1 LIMIT 1; -""" +""").sql(pretty=True) @rewrap_exceptions( { - NoDataFoundError: partialclass(HTTPException, status_code=404), + asyncpg.NoDataFoundError: partialclass( + HTTPException, + status_code=404, + detail="No temporal data found for the specified execution" + ), } ) @wrap_in_class(dict, one=True) @@ -31,7 +33,21 @@ async def lookup_temporal_data( developer_id: UUID, # TODO: what to do with this parameter? execution_id: UUID, ) -> tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]: + """ + Lookup temporal data for a given execution. + + Parameters: + developer_id (UUID): The ID of the developer. + execution_id (UUID): The ID of the execution. + + Returns: + tuple[str, list, Literal["fetch", "fetchmany", "fetchrow"]]: SQL query and parameters for looking up temporal data. + """ developer_id = str(developer_id) execution_id = str(execution_id) - return (sql_query, [execution_id], "fetchrow") + return ( + lookup_temporal_data_query, + [execution_id], + "fetchrow", + ) diff --git a/agents-api/agents_api/queries/executions/prepare_execution_input.py b/agents-api/agents_api/queries/executions/prepare_execution_input.py index 51ddec7a6..b751d2eb0 100644 --- a/agents-api/agents_api/queries/executions/prepare_execution_input.py +++ b/agents-api/agents_api/queries/executions/prepare_execution_input.py @@ -1,18 +1,17 @@ -from typing import Any, TypeVar +from typing import Any from uuid import UUID from beartype import beartype - +from sqlglot import parse_one from ...common.protocol.tasks import ExecutionInput from ..utils import ( pg_query, wrap_in_class, ) -ModelT = TypeVar("ModelT", bound=Any) -T = TypeVar("T") - -sql_query = """SELECT * FROM +# Query to prepare execution input +prepare_execution_input_query = parse_one(""" +SELECT * FROM ( SELECT to_jsonb(a) AS agent FROM ( SELECT * FROM agents @@ -42,7 +41,7 @@ LIMIT 1 ) t ) AS task; -""" +""").sql(pretty=True) # ( # SELECT to_jsonb(e) AS execution FROM ( # SELECT * FROM latest_executions @@ -89,8 +88,19 @@ async def prepare_execution_input( task_id: UUID, execution_id: UUID, ) -> tuple[str, list]: + """ + Prepare the execution input for a given task. + + Parameters: + developer_id (UUID): The ID of the developer. + task_id (UUID): The ID of the task. + execution_id (UUID): The ID of the execution. + + Returns: + tuple[str, list]: SQL query and parameters for preparing the execution input. + """ return ( - sql_query, + prepare_execution_input_query, [ str(developer_id), str(task_id), diff --git a/agents-api/agents_api/queries/tasks/create_or_update_task.py b/agents-api/agents_api/queries/tasks/create_or_update_task.py index 09b4a192d..9adde2d73 100644 --- a/agents-api/agents_api/queries/tasks/create_or_update_task.py +++ b/agents-api/agents_api/queries/tasks/create_or_update_task.py @@ -42,6 +42,7 @@ ) """).sql(pretty=True) +# Define the raw SQL query for creating or updating a task task_query = parse_one(""" WITH current_version AS ( SELECT COALESCE( diff --git a/agents-api/agents_api/queries/tasks/delete_task.py b/agents-api/agents_api/queries/tasks/delete_task.py index 20e03e28a..575397426 100644 --- a/agents-api/agents_api/queries/tasks/delete_task.py +++ b/agents-api/agents_api/queries/tasks/delete_task.py @@ -7,19 +7,21 @@ from ...autogen.openapi_model import ResourceDeletedResponse from ...common.utils.datetime import utcnow -from ...metrics.counters import increase_counter +from sqlglot import parse_one from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class -workflow_query = """ +# Define the raw SQL query for deleting workflows +workflow_query = parse_one(""" DELETE FROM workflows WHERE developer_id = $1 AND task_id = $2; -""" +""").sql(pretty=True) -task_query = """ +# Define the raw SQL query for deleting tasks +task_query = parse_one(""" DELETE FROM tasks WHERE developer_id = $1 AND task_id = $2 RETURNING task_id; -""" +""").sql(pretty=True) @rewrap_exceptions( @@ -49,7 +51,6 @@ "deleted_at": utcnow(), }, ) -@increase_counter("delete_task") @pg_query @beartype async def delete_task( diff --git a/agents-api/agents_api/queries/tasks/get_task.py b/agents-api/agents_api/queries/tasks/get_task.py index 1f0dd00cd..902a4fcde 100644 --- a/agents-api/agents_api/queries/tasks/get_task.py +++ b/agents-api/agents_api/queries/tasks/get_task.py @@ -6,10 +6,11 @@ from fastapi import HTTPException from ...common.protocol.tasks import spec_to_task -from ...metrics.counters import increase_counter +from sqlglot import parse_one from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class -get_task_query = """ +# Define the raw SQL query for getting a task +get_task_query = parse_one(""" SELECT t.*, COALESCE( @@ -35,7 +36,7 @@ WHERE developer_id = $1 AND task_id = $2 ) GROUP BY t.developer_id, t.task_id, t.canonical_name, t.agent_id, t.version; -""" +""").sql(pretty=True) @rewrap_exceptions( @@ -58,7 +59,6 @@ } ) @wrap_in_class(spec_to_task, one=True) -@increase_counter("get_task") @pg_query @beartype async def get_task( diff --git a/agents-api/agents_api/queries/tasks/list_tasks.py b/agents-api/agents_api/queries/tasks/list_tasks.py index 8a284fd2c..9c8d765a4 100644 --- a/agents-api/agents_api/queries/tasks/list_tasks.py +++ b/agents-api/agents_api/queries/tasks/list_tasks.py @@ -6,9 +6,9 @@ from fastapi import HTTPException from ...common.protocol.tasks import spec_to_task -from ...metrics.counters import increase_counter from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class +# Define the raw SQL query for listing tasks list_tasks_query = """ SELECT t.*, @@ -17,12 +17,7 @@ CASE WHEN w.name IS NOT NULL THEN jsonb_build_object( 'name', w.name, - 'steps', jsonb_build_array( - jsonb_build_object( - w.step_type, w.step_definition, - 'step_idx', w.step_idx -- Not sure if this is needed - ) - ) + 'steps', jsonb_build_array(w.step_definition) ) END ) FILTER (WHERE w.name IS NOT NULL), @@ -66,7 +61,6 @@ } ) @wrap_in_class(spec_to_task) -@increase_counter("list_tasks") @pg_query @beartype async def list_tasks( diff --git a/agents-api/agents_api/queries/tasks/patch_task.py b/agents-api/agents_api/queries/tasks/patch_task.py index 48111a333..a7b3f809e 100644 --- a/agents-api/agents_api/queries/tasks/patch_task.py +++ b/agents-api/agents_api/queries/tasks/patch_task.py @@ -12,25 +12,6 @@ from ...metrics.counters import increase_counter from ..utils import partialclass, pg_query, rewrap_exceptions, wrap_in_class -# # Update task query using UPDATE -# update_task_query = parse_one(""" -# UPDATE tasks -# SET -# version = version + 1, -# canonical_name = $2, -# agent_id = $4, -# metadata = $5, -# name = $6, -# description = $7, -# inherit_tools = $8, -# input_schema = $9::jsonb, -# updated_at = NOW() -# WHERE -# developer_id = $1 -# AND task_id = $3 -# RETURNING *; -# """).sql(pretty=True) - # Update task query using INSERT with version increment patch_task_query = parse_one(""" WITH current_version AS ( @@ -215,6 +196,14 @@ async def patch_task( ) return [ - (patch_task_query, patch_task_params, "fetchrow"), - (workflow_query, workflow_params, "fetchmany"), + ( + patch_task_query, + patch_task_params, + "fetchrow", + ), + ( + workflow_query, + workflow_params, + "fetchmany", + ), ] diff --git a/agents-api/agents_api/queries/tasks/update_task.py b/agents-api/agents_api/queries/tasks/update_task.py index 0379e0312..f97384d52 100644 --- a/agents-api/agents_api/queries/tasks/update_task.py +++ b/agents-api/agents_api/queries/tasks/update_task.py @@ -100,7 +100,12 @@ @wrap_in_class( ResourceUpdatedResponse, one=True, - transform=lambda d: {"id": d["task_id"], "updated_at": utcnow()}, + transform=lambda d: + { + "id": d["task_id"], + "updated_at": utcnow(), + "jobs": [], + }, ) @increase_counter("update_task") @pg_query(return_index=0) diff --git a/agents-api/agents_api/queries/users/create_user.py b/agents-api/agents_api/queries/users/create_user.py index 982d7a97e..8d86efd7a 100644 --- a/agents-api/agents_api/queries/users/create_user.py +++ b/agents-api/agents_api/queries/users/create_user.py @@ -73,7 +73,7 @@ async def create_user( tuple[str, list]: A tuple containing the SQL query and its parameters. """ user_id = user_id or uuid7() - metadata = data.metadata.model_dump(mode="json") or {} + metadata = data.metadata or {} params = [ developer_id, # $1 diff --git a/agents-api/tests/test_execution_queries.py b/agents-api/tests/test_execution_queries.py index 2abe9e5b4..316d91bde 100644 --- a/agents-api/tests/test_execution_queries.py +++ b/agents-api/tests/test_execution_queries.py @@ -26,7 +26,7 @@ test_task, ) -MODEL = "gpt-4o-mini-mini" +MODEL = "gpt-4o-mini" @test("query: create execution") @@ -51,6 +51,9 @@ async def _(dsn=pg_dsn, developer_id=test_developer_id, task=test_task): connection_pool=pool, ) + assert execution.status == "queued" + assert execution.input == {"test": "test"} + @test("query: get execution") async def _(dsn=pg_dsn, developer_id=test_developer_id, execution=test_execution): diff --git a/agents-api/tests/test_files_routes.py b/agents-api/tests/test_file_routes.py similarity index 93% rename from agents-api/tests/test_files_routes.py rename to agents-api/tests/test_file_routes.py index f0dca00bf..05507a786 100644 --- a/agents-api/tests/test_files_routes.py +++ b/agents-api/tests/test_file_routes.py @@ -48,12 +48,12 @@ async def _(make_request=make_request, s3_client=s3_client): assert response.status_code == 202 - # response = make_request( - # method="GET", - # url=f"/files/1", - # ) + response = make_request( + method="GET", + url=f"/files/{file_id}", + ) - # assert response.status_code == 404 + assert response.status_code == 404 @test("route: get file")