Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make filters param optional and fix typing #44226

Merged
merged 1 commit into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 45 additions & 11 deletions airflow/api_fastapi/common/db/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Database helpers for Airflow REST API.
:meta private:
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Sequence
from typing import TYPE_CHECKING, Literal, Sequence, overload

from airflow.utils.db import get_query_count
from airflow.utils.session import NEW_SESSION, create_session, provide_session
Expand Down Expand Up @@ -47,30 +52,59 @@ def your_route(session: Annotated[Session, Depends(get_session)]):
yield session


def apply_filters_to_select(base_select: Select, filters: Sequence[BaseParam | None]) -> Select:
base_select = base_select
for filter in filters:
if filter is None:
def apply_filters_to_select(
*, base_select: Select, filters: Sequence[BaseParam | None] | None = None
) -> Select:
if filters is None:
return base_select
for f in filters:
dstandish marked this conversation as resolved.
Show resolved Hide resolved
if f is None:
dstandish marked this conversation as resolved.
Show resolved Hide resolved
continue
base_select = filter.to_orm(base_select)
base_select = f.to_orm(base_select)
dstandish marked this conversation as resolved.
Show resolved Hide resolved

return base_select


@overload
def paginated_select(
*,
select: Select,
filters: Sequence[BaseParam] | None = None,
order_by: BaseParam | None = None,
offset: BaseParam | None = None,
limit: BaseParam | None = None,
session: Session = NEW_SESSION,
return_total_entries: Literal[True] = True,
) -> tuple[Select, int]: ...


@overload
def paginated_select(
*,
select: Select,
filters: Sequence[BaseParam] | None = None,
order_by: BaseParam | None = None,
offset: BaseParam | None = None,
limit: BaseParam | None = None,
session: Session = NEW_SESSION,
return_total_entries: Literal[False],
) -> tuple[Select, None]: ...


@provide_session
def paginated_select(
*,
select: Select,
filters: Sequence[BaseParam],
filters: Sequence[BaseParam] | None = None,
order_by: BaseParam | None = None,
offset: BaseParam | None = None,
limit: BaseParam | None = None,
session: Session = NEW_SESSION,
return_total_entries: bool = True,
) -> Select:
) -> tuple[Select, int | None]:
base_select = apply_filters_to_select(
select,
filters,
base_select=select,
filters=filters,
)

total_entries = None
Expand All @@ -82,6 +116,6 @@ def paginated_select(
# readable_dags = get_auth_manager().get_permitted_dag_ids(user=g.user)
# dags_select = dags_select.where(DagModel.dag_id.in_(readable_dags))

base_select = apply_filters_to_select(base_select, [order_by, offset, limit])
base_select = apply_filters_to_select(base_select=base_select, filters=[order_by, offset, limit])

return base_select, total_entries
6 changes: 3 additions & 3 deletions airflow/api_fastapi/core_api/routes/public/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def get_assets(
limit=limit,
session=session,
)

assets = session.scalars(
assets_select.options(
subqueryload(AssetModel.consuming_dags), subqueryload(AssetModel.producing_tasks)
Expand Down Expand Up @@ -211,7 +212,7 @@ def get_asset_queued_events(
.where(*where_clause)
)

dag_asset_queued_events_select, total_entries = paginated_select(select=query, filters=[])
dag_asset_queued_events_select, total_entries = paginated_select(select=query)
adrqs = session.execute(dag_asset_queued_events_select).all()

if not adrqs:
Expand Down Expand Up @@ -270,9 +271,8 @@ def get_dag_asset_queued_events(
.where(*where_clause)
)

dag_asset_queued_events_select, total_entries = paginated_select(select=query, filters=[])
dag_asset_queued_events_select, total_entries = paginated_select(select=query)
adrqs = session.execute(dag_asset_queued_events_select).all()

if not adrqs:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"Queue event with dag_id: `{dag_id}` was not found")

Expand Down
4 changes: 2 additions & 2 deletions airflow/api_fastapi/core_api/routes/public/backfills.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,16 @@ def list_backfills(
) -> BackfillCollectionResponse:
select_stmt, total_entries = paginated_select(
select=select(Backfill).where(Backfill.dag_id == dag_id),
filters=[],
order_by=order_by,
offset=offset,
limit=limit,
session=session,
)

backfills = session.scalars(select_stmt)

return BackfillCollectionResponse(
backfills=[BackfillResponse.model_validate(x, from_attributes=True) for x in backfills],
backfills=[BackfillResponse.model_validate(b, from_attributes=True) for b in backfills],
total_entries=total_entries,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,6 @@ def get_connections(
"""Get all connection entries."""
connection_select, total_entries = paginated_select(
select=select(Connection),
filters=[],
order_by=order_by,
offset=offset,
limit=limit,
Expand Down
3 changes: 1 addition & 2 deletions airflow/api_fastapi/core_api/routes/public/dag_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,9 +295,8 @@ def get_dag_runs(
limit=limit,
session=session,
)

dag_runs = session.scalars(dag_run_select)
return DAGRunCollectionResponse(
dag_runs=[DAGRunResponse.model_validate(dag_run, from_attributes=True) for dag_run in dag_runs],
dag_runs=[DAGRunResponse.model_validate(dr, from_attributes=True) for dr in dag_runs],
total_entries=total_entries,
)
6 changes: 1 addition & 5 deletions airflow/api_fastapi/core_api/routes/public/dag_warning.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,9 @@ def list_dag_warnings(
limit=limit,
session=session,
)

dag_warnings = session.scalars(dag_warnings_select)

return DAGWarningCollectionResponse(
dag_warnings=[
DAGWarningResponse.model_validate(dag_warning, from_attributes=True)
for dag_warning in dag_warnings
],
dag_warnings=[DAGWarningResponse.model_validate(w, from_attributes=True) for w in dag_warnings],
total_entries=total_entries,
)
8 changes: 3 additions & 5 deletions airflow/api_fastapi/core_api/routes/public/dags.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def get_dag_tags(
session=session,
)
dag_tags = session.execute(dag_tags_select).scalars().all()
return DAGTagCollectionResponse(tags=[dag_tag for dag_tag in dag_tags], total_entries=total_entries)
return DAGTagCollectionResponse(tags=[x for x in dag_tags], total_entries=total_entries)
dstandish marked this conversation as resolved.
Show resolved Hide resolved


@dags_router.get(
Expand Down Expand Up @@ -259,6 +259,7 @@ def patch_dags(
status.HTTP_400_BAD_REQUEST, "Only `is_paused` field can be updated through the REST API"
)
else:
# todo: this is not used?
update_mask = ["is_paused"]

dags_select, total_entries = paginated_select(
Expand All @@ -269,11 +270,8 @@ def patch_dags(
limit=limit,
session=session,
)

dags = session.scalars(dags_select).all()

dags_to_update = {dag.dag_id for dag in dags}

session.execute(
update(DagModel)
.where(DagModel.dag_id.in_(dags_to_update))
Expand All @@ -282,7 +280,7 @@ def patch_dags(
)

return DAGCollectionResponse(
dags=[DAGResponse.model_validate(dag, from_attributes=True) for dag in dags],
dags=[DAGResponse.model_validate(d, from_attributes=True) for d in dags],
total_entries=total_entries,
)

Expand Down
9 changes: 1 addition & 8 deletions airflow/api_fastapi/core_api/routes/public/event_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,6 @@ def get_event_logs(
base_select = base_select.where(Log.dttm > after)
event_logs_select, total_entries = paginated_select(
select=base_select,
filters=[],
order_by=order_by,
offset=offset,
limit=limit,
Expand All @@ -135,12 +134,6 @@ def get_event_logs(
event_logs = session.scalars(event_logs_select)

return EventLogCollectionResponse(
event_logs=[
EventLogResponse.model_validate(
event_log,
from_attributes=True,
)
for event_log in event_logs
],
event_logs=[EventLogResponse.model_validate(e, from_attributes=True) for e in event_logs],
total_entries=total_entries,
)
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ def get_import_errors(
"""Get all import errors."""
import_errors_select, total_entries = paginated_select(
select=select(ParseImportError),
filters=[],
order_by=order_by,
offset=offset,
limit=limit,
Expand All @@ -99,8 +98,6 @@ def get_import_errors(
import_errors = session.scalars(import_errors_select)

return ImportErrorCollectionResponse(
import_errors=[
ImportErrorResponse.model_validate(error, from_attributes=True) for error in import_errors
],
import_errors=[ImportErrorResponse.model_validate(i, from_attributes=True) for i in import_errors],
total_entries=total_entries,
)
1 change: 0 additions & 1 deletion airflow/api_fastapi/core_api/routes/public/pools.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ def get_pools(
"""Get all pools entries."""
pools_select, total_entries = paginated_select(
select=select(Pool),
filters=[],
order_by=order_by,
offset=offset,
limit=limit,
Expand Down
14 changes: 2 additions & 12 deletions airflow/api_fastapi/core_api/routes/public/task_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ def get_mapped_task_instances(
limit=limit,
session=session,
)

task_instances = session.scalars(task_instance_select)

return TaskInstanceCollectionResponse(
Expand Down Expand Up @@ -335,14 +334,9 @@ def get_task_instances(
limit=limit,
session=session,
)

task_instances = session.scalars(task_instance_select)

return TaskInstanceCollectionResponse(
task_instances=[
TaskInstanceResponse.model_validate(task_instance, from_attributes=True)
for task_instance in task_instances
],
task_instances=[TaskInstanceResponse.model_validate(t, from_attributes=True) for t in task_instances],
total_entries=total_entries,
)

Expand Down Expand Up @@ -411,18 +405,14 @@ def get_task_instances_batch(
limit=limit,
session=session,
)

task_instance_select = task_instance_select.options(
joinedload(TI.rendered_task_instance_fields), joinedload(TI.task_instance_note)
)

task_instances = session.scalars(task_instance_select)

return TaskInstanceCollectionResponse(
task_instances=[
TaskInstanceResponse.model_validate(task_instance, from_attributes=True)
for task_instance in task_instances
],
task_instances=[TaskInstanceResponse.model_validate(t, from_attributes=True) for t in task_instances],
total_entries=total_entries,
)

Expand Down
1 change: 0 additions & 1 deletion airflow/api_fastapi/core_api/routes/public/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ def get_variables(
"""Get all Variables entries."""
variable_select, total_entries = paginated_select(
select=select(Variable),
filters=[],
order_by=order_by,
offset=offset,
limit=limit,
Expand Down