diff --git a/airflow-core/tests/conftest.py b/airflow-core/tests/conftest.py index bffa1c61e8912..b4e3e66b6bdaf 100644 --- a/airflow-core/tests/conftest.py +++ b/airflow-core/tests/conftest.py @@ -161,3 +161,12 @@ def requests_mock() -> RequestsMockFixture: ... # time-machine @pytest.fixture # type: ignore[no-redef] def time_machine() -> TimeMachineFixture: ... + + +@pytest.fixture(autouse=True) +def _clear_in_process_api_cache(): + """Clear the cached InProcessExecutionAPI after each test to prevent state leakage.""" + yield + supervisor_module = sys.modules.get("airflow.sdk.execution_time.supervisor") + if supervisor_module is not None: + supervisor_module.in_process_api_server.cache_clear() diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index 8675305dde967..b34d17e0bc3b6 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -21,6 +21,7 @@ import atexit import contextlib +import functools import io import logging import os @@ -1531,6 +1532,7 @@ def _send_new_log_fd(self, req_id: int) -> None: child_logs.close() # Close this end now. +@functools.lru_cache(maxsize=1) def in_process_api_server(): from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI @@ -1692,8 +1694,9 @@ def start( # type: ignore[override] @staticmethod def _api_client(dag=None): api = in_process_api_server() + from airflow.api_fastapi.common.dagbag import dag_bag_from_app + if dag is not None: - from airflow.api_fastapi.common.dagbag import dag_bag_from_app from airflow.models.dagbag import DBDagBag # This is needed since the Execution API server uses the DBDagBag in its "state". @@ -1701,6 +1704,8 @@ def _api_client(dag=None): dag_bag = DBDagBag() api.app.dependency_overrides[dag_bag_from_app] = lambda: dag_bag + else: + api.app.dependency_overrides.pop(dag_bag_from_app, None) client = InProcessTestSupervisor._Client( base_url=None, token="", dry_run=True, transport=api.transport diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index e88b5c794cfda..33271d1f8f3f6 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -141,6 +141,7 @@ InProcessTestSupervisor, _make_process_nondumpable, _remote_logging_conn, + in_process_api_server, process_log_messages_from_subprocess, set_supervisor_comms, supervise, @@ -3297,3 +3298,39 @@ def test_nondumpable_noop_on_non_linux(): """On non-Linux, _make_process_nondumpable returns without error.""" _make_process_nondumpable() + + +def test_in_process_api_server_caches_instance(): + """in_process_api_server() returns the same instance on repeated calls.""" + in_process_api_server.cache_clear() + try: + first = in_process_api_server() + second = in_process_api_server() + assert first is second + + in_process_api_server.cache_clear() + third = in_process_api_server() + assert third is not first + finally: + in_process_api_server.cache_clear() + + +def test_api_client_clears_dag_bag_override_when_dag_is_none(): + """_api_client(dag=None) removes stale dag_bag_from_app overrides set by a previous call.""" + from unittest.mock import MagicMock + + from airflow.api_fastapi.common.dagbag import dag_bag_from_app + + in_process_api_server.cache_clear() + try: + # First call with a dag sets the override + mock_dag = MagicMock() + InProcessTestSupervisor._api_client(dag=mock_dag) + api = in_process_api_server() + assert dag_bag_from_app in api.app.dependency_overrides + + # Second call with dag=None should remove it + InProcessTestSupervisor._api_client(dag=None) + assert dag_bag_from_app not in api.app.dependency_overrides + finally: + in_process_api_server.cache_clear()