diff --git a/airflow/api_fastapi/common/db/common.py b/airflow/api_fastapi/common/db/common.py index 8b2c4a3d75ae0..1b71bbc47249d 100644 --- a/airflow/api_fastapi/common/db/common.py +++ b/airflow/api_fastapi/common/db/common.py @@ -24,8 +24,10 @@ from typing import TYPE_CHECKING, Sequence -from airflow.utils.db import get_query_count -from airflow.utils.session import NEW_SESSION, create_session, provide_session +from sqlalchemy.ext.asyncio import AsyncSession + +from airflow.utils.db import get_query_count, get_query_count_async +from airflow.utils.session import NEW_SESSION, create_session, create_session_async, provide_session if TYPE_CHECKING: from sqlalchemy.orm import Session @@ -53,10 +55,13 @@ def your_route(session: Annotated[Session, Depends(get_session)]): def apply_filters_to_select( - *, base_select: Select, filters: Sequence[BaseParam | None] | None = None + *, + base_select: Select, + filters: Sequence[BaseParam | None] | None = None, ) -> Select: if filters is None: return base_select + for f in filters: if f is None: continue @@ -65,6 +70,22 @@ def apply_filters_to_select( return base_select +async def get_async_session() -> AsyncSession: + """ + Dependency for providing an async session. + + Example usage: + + .. 
code:: python + + @router.get("/your_path") + async def your_route(session: Annotated[AsyncSession, Depends(get_async_session)]): + pass + """ + async with create_session_async() as session: + yield session + + @provide_session def paginated_select( base_select: Select, @@ -94,3 +115,35 @@ def paginated_select( base_select = apply_filters_to_select(base_select=base_select, filters=[order_by, offset, limit]) return base_select, total_entries + + +async def paginated_select_async( + *, + base_select: Select, + filters: Sequence[BaseParam | None] | None = None, + order_by: BaseParam | None = None, + offset: BaseParam | None = None, + limit: BaseParam | None = None, + session: AsyncSession, + return_total_entries: bool = True, +) -> tuple[Select, int | None]: + base_select = apply_filters_to_select( + base_select=base_select, + filters=filters, + ) + + total_entries = None + if return_total_entries: + total_entries = await get_query_count_async(base_select, session=session) + + # TODO: Re-enable when permissions are handled. Readable / writable entities, + # for instance: + # readable_dags = get_auth_manager().get_permitted_dag_ids(user=g.user) + # dags_select = dags_select.where(DagModel.dag_id.in_(readable_dags)) + + base_select = apply_filters_to_select( + base_select=base_select, + filters=[order_by, offset, limit], + ) + + return base_select, total_entries diff --git a/airflow/api_fastapi/core_api/routes/public/backfills.py b/airflow/api_fastapi/core_api/routes/public/backfills.py index 6f702171fec97..868f8d5bb586b 100644 --- a/airflow/api_fastapi/core_api/routes/public/backfills.py +++ b/airflow/api_fastapi/core_api/routes/public/backfills.py @@ -16,13 +16,14 @@ # under the License. 
from __future__ import annotations -from typing import Annotated +from typing import TYPE_CHECKING, Annotated from fastapi import Depends, HTTPException, status from sqlalchemy import select, update +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session -from airflow.api_fastapi.common.db.common import get_session, paginated_select +from airflow.api_fastapi.common.db.common import get_async_session, get_session, paginated_select_async from airflow.api_fastapi.common.parameters import QueryLimit, QueryOffset, SortParam from airflow.api_fastapi.common.router import AirflowRouter from airflow.api_fastapi.core_api.datamodels.backfills import ( @@ -49,7 +50,7 @@ @backfills_router.get( path="", ) -def list_backfills( +async def list_backfills( dag_id: str, limit: QueryLimit, offset: QueryOffset, @@ -57,17 +58,18 @@ def list_backfills( SortParam, Depends(SortParam(["id"], Backfill).dynamic_depends()), ], - session: Annotated[Session, Depends(get_session)], + session: Annotated[AsyncSession, Depends(get_async_session)], ) -> BackfillCollectionResponse: - select_stmt, total_entries = paginated_select( - select(Backfill).where(Backfill.dag_id == dag_id), + select_stmt, total_entries = await paginated_select_async( + base_select=select(Backfill).where(Backfill.dag_id == dag_id), order_by=order_by, offset=offset, limit=limit, session=session, ) - backfills = session.scalars(select_stmt) - + backfills = await session.scalars(select_stmt) + if TYPE_CHECKING: + assert isinstance(total_entries, int) return BackfillCollectionResponse( backfills=[BackfillResponse.model_validate(x, from_attributes=True) for x in backfills], total_entries=total_entries, diff --git a/airflow/settings.py b/airflow/settings.py index 5b458efcba473..4078b635b888b 100644 --- a/airflow/settings.py +++ b/airflow/settings.py @@ -111,7 +111,7 @@ # this is achieved by the Session factory above. 
NonScopedSession: Callable[..., SASession] async_engine: AsyncEngine -create_async_session: Callable[..., AsyncSession] +session_maker_async: Callable[..., AsyncSession] # The JSON library to use for DAG Serialization and De-Serialization json = json @@ -469,7 +469,7 @@ def configure_orm(disable_connection_pool=False, pool_class=None): global Session global engine global async_engine - global create_async_session + global session_maker_async global NonScopedSession if os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true": @@ -498,7 +498,7 @@ def configure_orm(disable_connection_pool=False, pool_class=None): engine = create_engine(SQL_ALCHEMY_CONN, connect_args=connect_args, **engine_args, future=True) async_engine = create_async_engine(SQL_ALCHEMY_CONN_ASYNC, future=True) - create_async_session = sessionmaker( + session_maker_async = sessionmaker( bind=async_engine, autocommit=False, autoflush=False, diff --git a/airflow/utils/db.py b/airflow/utils/db.py index d8939a117317f..00d20bbbca9fa 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -70,6 +70,7 @@ from alembic.runtime.environment import EnvironmentContext from alembic.script import ScriptDirectory from sqlalchemy.engine import Row + from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session from sqlalchemy.sql.elements import ClauseElement, TextClause from sqlalchemy.sql.selectable import Select @@ -1447,6 +1448,21 @@ def get_query_count(query_stmt: Select, *, session: Session) -> int: return session.scalar(count_stmt) +async def get_query_count_async(query_stmt: Select, *, session: AsyncSession) -> int: + """ + Get count of a query. + + A SELECT COUNT() FROM is issued against the subquery built from the + given statement. The ORDER BY clause is stripped from the statement + since it's unnecessary for COUNT, and can impact query planning and + degrade performance. 
+ + :meta private: + """ + count_stmt = select(func.count()).select_from(query_stmt.order_by(None).subquery()) + return await session.scalar(count_stmt) + + def check_query_exists(query_stmt: Select, *, session: Session) -> bool: """ Check whether there is at least one row matching a query. diff --git a/airflow/utils/session.py b/airflow/utils/session.py index a63d3f3f937a8..6336eed6deb76 100644 --- a/airflow/utils/session.py +++ b/airflow/utils/session.py @@ -65,6 +65,24 @@ def create_session(scoped: bool = True) -> Generator[SASession, None, None]: session.close() +@contextlib.asynccontextmanager +async def create_session_async(): + """ + Async context manager that yields an async session, committing on normal exit and rolling back on error. + + :meta private: + """ + from airflow.settings import session_maker_async + + async with session_maker_async() as session: + try: + yield session + await session.commit() + except Exception: + await session.rollback() + raise + + PS = ParamSpec("PS") RT = TypeVar("RT") diff --git a/tests/utils/test_session.py b/tests/utils/test_session.py index 02cba9e070dc4..db21b3c1f4edb 100644 --- a/tests/utils/test_session.py +++ b/tests/utils/test_session.py @@ -58,9 +58,9 @@ def test_provide_session_with_kwargs(self): @pytest.mark.asyncio async def test_async_session(self): - from airflow.settings import create_async_session + from airflow.settings import session_maker_async - session = create_async_session() + session = session_maker_async() session.add(Log(event="hihi1234")) await session.commit() my_special_log_event = await session.scalar(select(Log).where(Log.event == "hihi1234").limit(1))