Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions airflow-core/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,12 @@ def requests_mock() -> RequestsMockFixture: ...
# time-machine
@pytest.fixture # type: ignore[no-redef]
def time_machine() -> TimeMachineFixture: ...


@pytest.fixture(autouse=True)
def _clear_in_process_api_cache():
"""Clear the cached InProcessExecutionAPI after each test to prevent state leakage."""
yield
supervisor_module = sys.modules.get("airflow.sdk.execution_time.supervisor")
if supervisor_module is not None:
supervisor_module.in_process_api_server.cache_clear()
7 changes: 6 additions & 1 deletion task-sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import atexit
import contextlib
import functools
import io
import logging
import os
Expand Down Expand Up @@ -1531,6 +1532,7 @@ def _send_new_log_fd(self, req_id: int) -> None:
child_logs.close() # Close this end now.


@functools.lru_cache(maxsize=1)
def in_process_api_server():
from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI

Expand Down Expand Up @@ -1692,15 +1694,18 @@ def start( # type: ignore[override]
@staticmethod
def _api_client(dag=None):
api = in_process_api_server()
from airflow.api_fastapi.common.dagbag import dag_bag_from_app

if dag is not None:
from airflow.api_fastapi.common.dagbag import dag_bag_from_app
from airflow.models.dagbag import DBDagBag

# This is needed since the Execution API server uses the DBDagBag in its "state".
# This `app.state.dag_bag` is used to get some Dag properties like `fail_fast`.
dag_bag = DBDagBag()

api.app.dependency_overrides[dag_bag_from_app] = lambda: dag_bag
else:
api.app.dependency_overrides.pop(dag_bag_from_app, None)
Comment thread
kaxil marked this conversation as resolved.

client = InProcessTestSupervisor._Client(
base_url=None, token="", dry_run=True, transport=api.transport
Expand Down
37 changes: 37 additions & 0 deletions task-sdk/tests/task_sdk/execution_time/test_supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@
InProcessTestSupervisor,
_make_process_nondumpable,
_remote_logging_conn,
in_process_api_server,
process_log_messages_from_subprocess,
set_supervisor_comms,
supervise,
Expand Down Expand Up @@ -3297,3 +3298,39 @@ def test_nondumpable_noop_on_non_linux():
"""On non-Linux, _make_process_nondumpable returns without error."""

_make_process_nondumpable()


def test_in_process_api_server_caches_instance():
"""in_process_api_server() returns the same instance on repeated calls."""
in_process_api_server.cache_clear()
try:
first = in_process_api_server()
second = in_process_api_server()
assert first is second

in_process_api_server.cache_clear()
third = in_process_api_server()
assert third is not first
finally:
in_process_api_server.cache_clear()


def test_api_client_clears_dag_bag_override_when_dag_is_none():
"""_api_client(dag=None) removes stale dag_bag_from_app overrides set by a previous call."""
from unittest.mock import MagicMock

from airflow.api_fastapi.common.dagbag import dag_bag_from_app

in_process_api_server.cache_clear()
try:
# First call with a dag sets the override
mock_dag = MagicMock()
InProcessTestSupervisor._api_client(dag=mock_dag)
api = in_process_api_server()
assert dag_bag_from_app in api.app.dependency_overrides

# Second call with dag=None should remove it
InProcessTestSupervisor._api_client(dag=None)
assert dag_bag_from_app not in api.app.dependency_overrides
finally:
in_process_api_server.cache_clear()
Loading