Skip to content

Commit

Permalink
Make list backfills endpoint use async
Browse files Browse the repository at this point in the history
This is a sort of hello world for having an route implemented using asyncio.
  • Loading branch information
dstandish committed Nov 20, 2024
1 parent 697edf0 commit 1efe639
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 16 deletions.
59 changes: 56 additions & 3 deletions airflow/api_fastapi/common/db/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,10 @@

from typing import TYPE_CHECKING, Sequence

from airflow.utils.db import get_query_count
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from sqlalchemy.ext.asyncio import AsyncSession

from airflow.utils.db import get_query_count, get_query_count_async
from airflow.utils.session import NEW_SESSION, create_session, create_session_async, provide_session

if TYPE_CHECKING:
from sqlalchemy.orm import Session
Expand Down Expand Up @@ -53,10 +55,13 @@ def your_route(session: Annotated[Session, Depends(get_session)]):


def apply_filters_to_select(
*, base_select: Select, filters: Sequence[BaseParam | None] | None = None
*,
base_select: Select,
filters: Sequence[BaseParam | None] | None = None,
) -> Select:
if filters is None:
return base_select

for f in filters:
if f is None:
continue
Expand All @@ -65,6 +70,22 @@ def apply_filters_to_select(
return base_select


async def get_async_session() -> AsyncSession:
"""
Dependency for providing a session.
Example usage:
.. code:: python
@router.get("/your_path")
def your_route(session: Annotated[AsyncSession, Depends(get_async_session)]):
pass
"""
async with create_session_async() as session:
yield session


@provide_session
def paginated_select(
base_select: Select,
Expand Down Expand Up @@ -94,3 +115,35 @@ def paginated_select(
base_select = apply_filters_to_select(base_select=base_select, filters=[order_by, offset, limit])

return base_select, total_entries


async def paginated_select_async(
*,
base_select: Select,
filters: Sequence[BaseParam | None] | None = None,
order_by: BaseParam | None = None,
offset: BaseParam | None = None,
limit: BaseParam | None = None,
session: AsyncSession,
return_total_entries: bool = True,
) -> tuple[Select, int | None]:
base_select = apply_filters_to_select(
base_select=base_select,
filters=filters,
)

total_entries = None
if return_total_entries:
total_entries = await get_query_count_async(base_select, session=session)

# TODO: Re-enable when permissions are handled. Readable / writable entities,
# for instance:
# readable_dags = get_auth_manager().get_permitted_dag_ids(user=g.user)
# dags_select = dags_select.where(DagModel.dag_id.in_(readable_dags))

base_select = apply_filters_to_select(
base_select=base_select,
filters=[order_by, offset, limit],
)

return base_select, total_entries
18 changes: 10 additions & 8 deletions airflow/api_fastapi/core_api/routes/public/backfills.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@
# under the License.
from __future__ import annotations

from typing import Annotated
from typing import TYPE_CHECKING, Annotated

from fastapi import Depends, HTTPException, status
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session

from airflow.api_fastapi.common.db.common import get_session, paginated_select
from airflow.api_fastapi.common.db.common import get_async_session, get_session, paginated_select_async
from airflow.api_fastapi.common.parameters import QueryLimit, QueryOffset, SortParam
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.core_api.datamodels.backfills import (
Expand All @@ -49,25 +50,26 @@
@backfills_router.get(
path="",
)
def list_backfills(
async def list_backfills(
dag_id: str,
limit: QueryLimit,
offset: QueryOffset,
order_by: Annotated[
SortParam,
Depends(SortParam(["id"], Backfill).dynamic_depends()),
],
session: Annotated[Session, Depends(get_session)],
session: Annotated[AsyncSession, Depends(get_async_session)],
) -> BackfillCollectionResponse:
select_stmt, total_entries = paginated_select(
select(Backfill).where(Backfill.dag_id == dag_id),
select_stmt, total_entries = await paginated_select_async(
base_select=select(Backfill).where(Backfill.dag_id == dag_id),
order_by=order_by,
offset=offset,
limit=limit,
session=session,
)
backfills = session.scalars(select_stmt)

backfills = await session.scalars(select_stmt)
if TYPE_CHECKING:
assert isinstance(total_entries, int)
return BackfillCollectionResponse(
backfills=[BackfillResponse.model_validate(x, from_attributes=True) for x in backfills],
total_entries=total_entries,
Expand Down
6 changes: 3 additions & 3 deletions airflow/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@
# this is achieved by the Session factory above.
NonScopedSession: Callable[..., SASession]
async_engine: AsyncEngine
create_async_session: Callable[..., AsyncSession]
session_maker_async: Callable[..., AsyncSession]

# The JSON library to use for DAG Serialization and De-Serialization
json = json
Expand Down Expand Up @@ -469,7 +469,7 @@ def configure_orm(disable_connection_pool=False, pool_class=None):
global Session
global engine
global async_engine
global create_async_session
global session_maker_async
global NonScopedSession

if os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true":
Expand Down Expand Up @@ -498,7 +498,7 @@ def configure_orm(disable_connection_pool=False, pool_class=None):

engine = create_engine(SQL_ALCHEMY_CONN, connect_args=connect_args, **engine_args, future=True)
async_engine = create_async_engine(SQL_ALCHEMY_CONN_ASYNC, future=True)
create_async_session = sessionmaker(
session_maker_async = sessionmaker(
bind=async_engine,
autocommit=False,
autoflush=False,
Expand Down
16 changes: 16 additions & 0 deletions airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
from alembic.runtime.environment import EnvironmentContext
from alembic.script import ScriptDirectory
from sqlalchemy.engine import Row
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from sqlalchemy.sql.elements import ClauseElement, TextClause
from sqlalchemy.sql.selectable import Select
Expand Down Expand Up @@ -1447,6 +1448,21 @@ def get_query_count(query_stmt: Select, *, session: Session) -> int:
return session.scalar(count_stmt)


async def get_query_count_async(query_stmt: Select, *, session: AsyncSession) -> int:
"""
Get count of a query.
A SELECT COUNT() FROM is issued against the subquery built from the
given statement. The ORDER BY clause is stripped from the statement
since it's unnecessary for COUNT, and can impact query planning and
degrade performance.
:meta private:
"""
count_stmt = select(func.count()).select_from(query_stmt.order_by(None).subquery())
return await session.scalar(count_stmt)


def check_query_exists(query_stmt: Select, *, session: Session) -> bool:
"""
Check whether there is at least one row matching a query.
Expand Down
18 changes: 18 additions & 0 deletions airflow/utils/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,24 @@ def create_session(scoped: bool = True) -> Generator[SASession, None, None]:
session.close()


@contextlib.asynccontextmanager
async def create_session_async():
"""
Context manager to create async session.
:meta private:
"""
from airflow.settings import session_maker_async

async with session_maker_async() as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise


PS = ParamSpec("PS")
RT = TypeVar("RT")

Expand Down
4 changes: 2 additions & 2 deletions tests/utils/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ def test_provide_session_with_kwargs(self):

@pytest.mark.asyncio
async def test_async_session(self):
from airflow.settings import create_async_session
from airflow.settings import session_maker_async

session = create_async_session()
session = session_maker_async()
session.add(Log(event="hihi1234"))
await session.commit()
my_special_log_event = await session.scalar(select(Log).where(Log.event == "hihi1234").limit(1))
Expand Down

0 comments on commit 1efe639

Please sign in to comment.