From 95536d10a65d9c1c18d1802e82be9b1e2ef42083 Mon Sep 17 00:00:00 2001 From: pierrejeambrun Date: Fri, 31 Oct 2025 19:47:12 +0100 Subject: [PATCH] Add number of queries guard in public task instances list endpoints --- .../api_fastapi/common/db/task_instances.py | 40 +++++++ .../core_api/routes/public/task_instances.py | 17 ++- .../routes/public/test_task_instances.py | 102 +++++++++++++----- 3 files changed, 121 insertions(+), 38 deletions(-) create mode 100644 airflow-core/src/airflow/api_fastapi/common/db/task_instances.py diff --git a/airflow-core/src/airflow/api_fastapi/common/db/task_instances.py b/airflow-core/src/airflow/api_fastapi/common/db/task_instances.py new file mode 100644 index 0000000000000..423e57f2316ae --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/common/db/task_instances.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from sqlalchemy.orm import joinedload +from sqlalchemy.orm.interfaces import LoaderOption + +from airflow.models import Base +from airflow.models.dag_version import DagVersion +from airflow.models.dagrun import DagRun +from airflow.models.taskinstance import TaskInstance + + +def eager_load_TI_and_TIH_for_validation(orm_model: Base | None = None) -> tuple[LoaderOption, ...]: + """Construct the eager loading options necessary for both TaskInstanceResponse and TaskInstanceHistoryResponse objects.""" + if orm_model is None: + orm_model = TaskInstance + + options: tuple[LoaderOption, ...] = ( + joinedload(orm_model.dag_version).joinedload(DagVersion.bundle), + joinedload(orm_model.dag_run).options(joinedload(DagRun.dag_model)), + ) + if orm_model is TaskInstance: + options += (joinedload(orm_model.task_instance_note),) + return options diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py index 740dc9868308c..25263a0e2376b 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py @@ -33,6 +33,7 @@ get_latest_version_of_dag, ) from airflow.api_fastapi.common.db.common import SessionDep, paginated_select +from airflow.api_fastapi.common.db.task_instances import eager_load_TI_and_TIH_for_validation from airflow.api_fastapi.common.parameters import ( FilterOptionEnum, FilterParam, @@ -193,8 +194,7 @@ def get_mapped_task_instances( select(TI) .where(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == task_id, TI.map_index >= 0) .join(TI.dag_run) - .options(joinedload(TI.dag_version)) - .options(joinedload(TI.dag_run).options(joinedload(DagRun.dag_model))) + .options(*eager_load_TI_and_TIH_for_validation()) ) # 0 can mean a mapped TI that expanded to an empty list, so it is not an automatic 404 unfiltered_total_count = get_query_count(query, session=session) @@ -324,8 +324,7 @@ def _query(orm_object: Base) -> Select: orm_object.task_id == task_id, orm_object.map_index == map_index, ) - .options(joinedload(orm_object.dag_version)) - .options(joinedload(orm_object.dag_run).options(joinedload(DagRun.dag_model))) + .options(*eager_load_TI_and_TIH_for_validation(orm_object)) .options(joinedload(orm_object.hitl_detail)) ) return query @@ -467,11 +466,7 @@ def get_task_instances( """ dag_run = None query = ( - select(TI) - .join(TI.dag_run) - .outerjoin(TI.dag_version) - .options(joinedload(TI.dag_version)) - .options(joinedload(TI.dag_run).options(joinedload(DagRun.dag_model))) + select(TI).join(TI.dag_run).outerjoin(TI.dag_version).options(*eager_load_TI_and_TIH_for_validation()) ) if dag_run_id != "~": dag_run = session.scalar(select(DagRun).filter_by(run_id=dag_run_id)) @@ -597,7 +592,9 @@ def get_task_instances_batch( TI, ).set_value([body.order_by] if body.order_by else None) - query = select(TI).join(TI.dag_run).outerjoin(TI.dag_version) + query = ( + select(TI).join(TI.dag_run).outerjoin(TI.dag_version).options(*eager_load_TI_and_TIH_for_validation()) + ) task_instance_select, total_entries = paginated_select( statement=query, filters=[ diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py index 6a49c9fb63b41..f1c6564e242c0 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py @@ -47,6 +47,7 @@ from airflow.utils.types import DagRunType from tests_common.test_utils.api_fastapi import _check_task_instance_note +from tests_common.test_utils.asserts import assert_queries_count from tests_common.test_utils.db import ( clear_db_runs, clear_rendered_ti_fields, @@ -762,9 +763,10 @@ def test_should_respond_404(self, test_client): assert response.json() == {"detail": "The Dag with ID: `mapped_tis` was not found"} def test_should_respond_200(self, one_task_with_many_mapped_tis, test_client): - response = test_client.get( - "/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped", - ) + with assert_queries_count(4): + response = test_client.get( + "/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped", + ) assert response.status_code == 200 assert response.json()["total_entries"] == 110 @@ -803,10 +805,11 @@ def test_offset_limit(self, test_client, one_task_with_many_mapped_tis): def test_mapped_instances_order( self, test_client, session, params, expected_map_indexes, one_task_with_many_mapped_tis ): - response = test_client.get( - "/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped", - params=params, - ) + with assert_queries_count(4): + response = test_client.get( + "/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped", + params=params, + ) assert response.status_code == 200 body = response.json() @@ -834,10 +837,11 @@ def test_rendered_map_index_order( session.commit() - response = test_client.get( - "/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped", - params=params, - ) + with assert_queries_count(4): + response = test_client.get( + "/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped", + params=params, + ) assert response.status_code == 200 body = response.json() assert body["total_entries"] == 110 @@ -935,7 +939,7 @@ def test_should_raise_404_not_found_for_nonexistent_task( class TestGetTaskInstances(TestTaskInstanceEndpoint): @pytest.mark.parametrize( - "task_instances, update_extras, url, params, expected_ti", + "task_instances, update_extras, url, params, expected_ti, expected_queries_number", [ pytest.param( [ @@ -947,6 +951,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): "/dags/example_python_operator/dagRuns/~/taskInstances", {"logical_date_lte": DEFAULT_DATETIME_1}, 1, + 5, id="test logical date filter", ), pytest.param( @@ -959,6 +964,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): "/dags/example_python_operator/dagRuns/~/taskInstances", {"start_date_gte": DEFAULT_DATETIME_1, "start_date_lte": DEFAULT_DATETIME_STR_2}, 2, + 5, id="test start date filter", ), pytest.param( @@ -974,6 +980,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): "start_date_lt": DEFAULT_DATETIME_STR_2, }, 1, + 5, id="test start date gt and lt filter", ), pytest.param( @@ -986,6 +993,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): "/dags/example_python_operator/dagRuns/~/taskInstances?", {"end_date_gte": DEFAULT_DATETIME_1, "end_date_lte": DEFAULT_DATETIME_STR_2}, 2, + 5, id="test end date filter", ), pytest.param( @@ -1001,6 +1009,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): "end_date_lt": (DEFAULT_DATETIME_2 + dt.timedelta(hours=1)).isoformat(), }, 1, + 5, id="test end date gt and lt filter", ), pytest.param( @@ -1013,6 +1022,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): "/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances", {"duration_gte": 100, "duration_lte": 200}, 3, + 7, id="test duration filter", ), pytest.param( @@ -1025,6 +1035,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): "/dags/~/dagRuns/~/taskInstances", {"duration_gte": 100, "duration_lte": 200}, 3, + 3, id="test duration filter ~", ), pytest.param( @@ -1037,6 +1048,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): "/dags/~/dagRuns/~/taskInstances", {"duration_gt": 100, "duration_lt": 200}, 1, + 3, id="test duration gt and lt filter ~", ), pytest.param( @@ -1050,6 +1062,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): ("/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances"), {"state": ["running", "queued", "none"]}, 3, + 7, id="test state filter", ), pytest.param( @@ -1063,6 +1076,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): ("/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances"), {"state": ["no_status"]}, 1, + 7, id="test no_status state filter", ), pytest.param( @@ -1076,6 +1090,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): ("/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances"), {}, 4, + 7, id="test null states with no filter", ), pytest.param( @@ -1084,6 +1099,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): "/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances", {"start_date_gte": DEFAULT_DATETIME_STR_1}, 1, + 7, id="test start_date coalesce with null", ), pytest.param( @@ -1096,6 +1112,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): ("/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances"), {"pool": ["test_pool_1", "test_pool_2"]}, 2, + 7, id="test pool filter", ), pytest.param( @@ -1108,6 +1125,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): "/dags/~/dagRuns/~/taskInstances", {"pool": ["test_pool_1", "test_pool_2"]}, 2, + 3, id="test pool filter ~", ), pytest.param( @@ -1120,6 +1138,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): "/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances", {"queue": ["test_queue_1", "test_queue_2"]}, 2, + 7, id="test queue filter", ), pytest.param( @@ -1132,6 +1151,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): "/dags/~/dagRuns/~/taskInstances", {"queue": ["test_queue_1", "test_queue_2"]}, 2, + 3, id="test queue filter ~", ), pytest.param( @@ -1144,6 +1164,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): ("/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances"), {"executor": ["test_exec_1", "test_exec_2"]}, 2, + 7, id="test_executor_filter", ), pytest.param( @@ -1156,6 +1177,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): "/dags/~/dagRuns/~/taskInstances", {"executor": ["test_exec_1", "test_exec_2"]}, 2, + 3, id="test executor filter ~", ), pytest.param( @@ -1168,6 +1190,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): ("/dags/~/dagRuns/~/taskInstances"), {"task_display_name_pattern": "task_name"}, 2, + 3, id="test task_display_name_pattern filter", ), pytest.param( @@ -1180,6 +1203,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): ("/dags/~/dagRuns/~/taskInstances"), {"task_id": "task_match_id_2"}, 1, + 3, id="test task_id filter", ), pytest.param( @@ -1190,6 +1214,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): ("/dags/~/dagRuns/~/taskInstances"), {"version_number": [2]}, 2, + 3, id="test version number filter", ), pytest.param( @@ -1201,6 +1226,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): {"version_number": [1, 2, 3]}, 7, # apart from the TIs in the fixture, we also get one from # the create_task_instances method + 3, id="test multiple version numbers filter", ), pytest.param( @@ -1216,6 +1242,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): ("/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances"), {"try_number": [0, 1]}, 5, + 7, id="test_try_number_filter", ), pytest.param( @@ -1234,6 +1261,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): ("/dags/~/dagRuns/~/taskInstances"), {"operator": ["FirstOperator", "SecondOperator"]}, 5, + 3, id="test operator type filter filter", ), pytest.param( @@ -1251,6 +1279,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): ("/dags/~/dagRuns/~/taskInstances"), {"map_index": [0, 1]}, 2, + 3, id="test map_index filter", ), pytest.param( @@ -1259,6 +1288,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): "/dags/~/dagRuns/~/taskInstances", {"dag_id_pattern": "example_python_operator"}, 9, # Based on test failure - example_python_operator creates 9 task instances + 3, id="test dag_id_pattern exact match", ), pytest.param( @@ -1267,6 +1297,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): "/dags/~/dagRuns/~/taskInstances", {"dag_id_pattern": "example_%"}, 17, # Based on test failure - both DAGs together create 17 task instances + 3, id="test dag_id_pattern wildcard prefix", ), pytest.param( @@ -1275,6 +1306,7 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): "/dags/~/dagRuns/~/taskInstances", {"dag_id_pattern": "%skip%"}, 8, # Based on test failure - example_skip_dag creates 8 task instances + 3, id="test dag_id_pattern wildcard contains", ), pytest.param( @@ -1283,13 +1315,22 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint): "/dags/~/dagRuns/~/taskInstances", {"dag_id_pattern": "nonexistent"}, 0, + 3, id="test dag_id_pattern no match", ), ], ) @pytest.mark.usefixtures("make_dag_with_multiple_versions") def test_should_respond_200( - self, test_client, task_instances, update_extras, url, params, expected_ti, session + self, + test_client, + task_instances, + update_extras, + url, + params, + expected_ti, + expected_queries_number, + session, ): # Special handling for dag_id_pattern tests that require multiple DAGs if task_instances == "dag_id_pattern_test": @@ -1307,7 +1348,8 @@ def test_should_respond_200( with mock.patch("airflow.models.dag_version.DagBundlesManager") as dag_bundle_manager_mock: dag_bundle_manager_mock.return_value.view_url.return_value = "some_url" # Mock DagBundlesManager to avoid checking if dags-folder bundle is configured - response = test_client.get(url, params=params) + with assert_queries_count(expected_queries_number): + response = test_client.get(url, params=params) if params == {"task_id_pattern": "task_match_id"}: import pprint @@ -1664,10 +1706,11 @@ def test_should_respond_200( update_extras=update_extras, task_instances=task_instances, ) - response = test_client.post( - "/dags/~/dagRuns/~/taskInstances/list", - json=payload, - ) + with assert_queries_count(4): + response = test_client.post( + "/dags/~/dagRuns/~/taskInstances/list", + json=payload, + ) body = response.json() assert response.status_code == 200, body assert expected_ti_count == body["total_entries"] @@ -3333,9 +3376,10 @@ def test_should_respond_200(self, test_client, session): self.create_task_instances( session=session, task_instances=[{"state": State.SUCCESS}], with_ti_history=True ) - response = test_client.get( - "/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context/tries" - ) + with assert_queries_count(3): + response = test_client.get( + "/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context/tries" + ) assert response.status_code == 200 assert response.json()["total_entries"] == 2 # The task instance and its history assert len(response.json()["task_instances"]) == 2 @@ -3425,9 +3469,10 @@ def test_should_respond_200_with_hitl( TaskInstanceHistory.record_ti(ti, session=session) session.flush() - response = test_client.get( - f"/dags/{ti.dag_id}/dagRuns/{ti.run_id}/taskInstances/{ti.task_id}/tries", - ) + with assert_queries_count(3): + response = test_client.get( + f"/dags/{ti.dag_id}/dagRuns/{ti.run_id}/taskInstances/{ti.task_id}/tries", + ) assert response.status_code == 200 assert response.json() == { "task_instances": [ @@ -3569,10 +3614,11 @@ def test_mapped_task_should_respond_200(self, test_client, session): # in each loop, we should get the right mapped TI back for map_index in (1, 2): # Get the info from TIHistory: try_number 1, try_number 2 is TI table(latest) - response = test_client.get( - "/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances" - f"/print_the_context/{map_index}/tries", - ) + with assert_queries_count(3): + response = test_client.get( + "/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances" + f"/print_the_context/{map_index}/tries", + ) assert response.status_code == 200 assert ( response.json()["total_entries"] == 2