4 changes: 2 additions & 2 deletions airflow/api_internal/endpoints/rpc_api_endpoint.py
@@ -25,7 +25,6 @@
from flask import Response

from airflow.api_connexion.types import APIResponse
from airflow.dag_processing.manager import DagFileProcessorManager
from airflow.models import Variable, XCom
from airflow.serialization.serialized_objects import BaseSerialization

@@ -34,12 +33,14 @@

@functools.lru_cache()
def _initialize_map() -> dict[str, Callable]:
from airflow.dag_processing.manager import DagFileProcessorManager
from airflow.dag_processing.processor import DagFileProcessor
from airflow.models.dag import DagModel

functions: list[Callable] = [
DagFileProcessor.update_import_errors,
DagFileProcessor.manage_slas,
DagFileProcessorManager.deactivate_stale_dags,
DagModel.get_paused_dag_ids,
DagFileProcessorManager.clear_nonexistent_import_errors,
XCom.get_value,
@@ -62,7 +63,6 @@ def internal_airflow_api(body: dict[str, Any]) -> APIResponse:
return Response(response="Expected jsonrpc 2.0 request.", status=400)

methods_map = _initialize_map()

method_name = body.get("method")
if method_name not in methods_map:
log.error("Unrecognized method: %s.", method_name)
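Note on the change above: the module-level import of DagFileProcessorManager moves inside _initialize_map — the usual remedy for a circular import between the API endpoint and the dag-processing modules — and deactivate_stale_dags joins the list of methods callable over the internal API. A minimal sketch of the dispatch this map feeds; the key scheme (bare method name) is an assumption based on the "test_method" key used in the tests below:

```python
import functools
from typing import Any, Callable


@functools.lru_cache()
def _initialize_map() -> dict[str, Callable]:
    # Deferred import: resolving this at call time rather than module-load
    # time breaks the import cycle with the dag-processing package.
    from airflow.dag_processing.manager import DagFileProcessorManager

    functions: list[Callable] = [DagFileProcessorManager.deactivate_stale_dags]
    return {func.__name__: func for func in functions}


def dispatch(body: dict[str, Any]) -> Any:
    # Sketch of the endpoint's lookup-and-call step; error handling trimmed.
    method = _initialize_map().get(body["method"])
    if method is None:
        raise ValueError(f"Unrecognized method: {body['method']}")
    return method(**(body.get("params") or {}))
```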
91 changes: 53 additions & 38 deletions airflow/dag_processing/manager.py
@@ -485,52 +485,67 @@ def start(self):

return self._run_parsing_loop()

@provide_session
def _deactivate_stale_dags(self, session=None):
"""
Detects DAGs which are no longer present in files.

Deactivate them and remove them in the serialized_dag table
"""
def _scan_stale_dags(self):
"""Scan at fix internal DAGs which are no longer present in files."""
now = timezone.utcnow()
elapsed_time_since_refresh = (now - self.last_deactivate_stale_dags_time).total_seconds()
if elapsed_time_since_refresh > self.parsing_cleanup_interval:
last_parsed = {
fp: self.get_last_finish_time(fp) for fp in self.file_paths if self.get_last_finish_time(fp)
}
to_deactivate = set()
query = session.query(DagModel.dag_id, DagModel.fileloc, DagModel.last_parsed_time).filter(
DagModel.is_active
DagFileProcessorManager.deactivate_stale_dags(
last_parsed=last_parsed,
dag_directory=self.get_dag_directory(),
processor_timeout=self._processor_timeout,
)
if self.standalone_dag_processor:
query = query.filter(DagModel.processor_subdir == self.get_dag_directory())
dags_parsed = query.all()

for dag in dags_parsed:
# The largest valid difference between a DagFileStat's last_finished_time and a DAG's
# last_parsed_time is _processor_timeout. Longer than that indicates that the DAG is
# no longer present in the file.
if (
dag.fileloc in last_parsed
and (dag.last_parsed_time + self._processor_timeout) < last_parsed[dag.fileloc]
):
self.log.info("DAG %s is missing and will be deactivated.", dag.dag_id)
to_deactivate.add(dag.dag_id)

if to_deactivate:
deactivated = (
session.query(DagModel)
.filter(DagModel.dag_id.in_(to_deactivate))
.update({DagModel.is_active: False}, synchronize_session="fetch")
)
if deactivated:
self.log.info("Deactivated %i DAGs which are no longer present in file.", deactivated)
self.last_deactivate_stale_dags_time = timezone.utcnow()

for dag_id in to_deactivate:
SerializedDagModel.remove_dag(dag_id)
self.log.info("Deleted DAG %s in serialized_dag table", dag_id)
@classmethod
@internal_api_call
@provide_session
def deactivate_stale_dags(
cls,
last_parsed: dict[str, datetime | None],
dag_directory: str,
processor_timeout: timedelta,
session: Session = NEW_SESSION,
):
"""
Detects DAGs which are no longer present in files.
Deactivate them and remove them in the serialized_dag table
"""
to_deactivate = set()
query = session.query(DagModel.dag_id, DagModel.fileloc, DagModel.last_parsed_time).filter(
DagModel.is_active
)
standalone_dag_processor = conf.getboolean("scheduler", "standalone_dag_processor")
if standalone_dag_processor:
query = query.filter(DagModel.processor_subdir == dag_directory)
dags_parsed = query.all()

for dag in dags_parsed:
# The largest valid difference between a DagFileStat's last_finished_time and a DAG's
# last_parsed_time is _processor_timeout. Longer than that indicates that the DAG is
# no longer present in the file.
if (
dag.fileloc in last_parsed
and (dag.last_parsed_time + processor_timeout) < last_parsed[dag.fileloc]
):
cls.logger().info("DAG %s is missing and will be deactivated.", dag.dag_id)
to_deactivate.add(dag.dag_id)

if to_deactivate:
deactivated = (
session.query(DagModel)
.filter(DagModel.dag_id.in_(to_deactivate))
.update({DagModel.is_active: False}, synchronize_session="fetch")
)
if deactivated:
cls.logger().info("Deactivated %i DAGs which are no longer present in file.", deactivated)

self.last_deactivate_stale_dags_time = timezone.utcnow()
for dag_id in to_deactivate:
SerializedDagModel.remove_dag(dag_id)
cls.logger().info("Deleted DAG %s in serialized_dag table", dag_id)

def _run_parsing_loop(self):
# In sync mode we want timeout=None -- wait forever until a message is received
@@ -595,7 +610,7 @@ def _run_parsing_loop(self):

if self.standalone_dag_processor:
self._fetch_callbacks(max_callbacks_per_loop)
self._deactivate_stale_dags()
self._scan_stale_dags()
DagWarning.purge_inactive_dag_warnings()
refreshed_dag_dir = self._refresh_dag_dir()

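The refactor above splits the old _deactivate_stale_dags in two: _scan_stale_dags keeps the instance-level throttling (run only once parsing_cleanup_interval has elapsed), while the database work becomes the classmethod deactivate_stale_dags, which receives everything it needs as plain, serializable arguments so @internal_api_call can execute it either in-process or over RPC in standalone DAG-processor mode. A sketch of a direct invocation with the argument shapes the new signature expects — the path and timeout values are illustrative, and an initialized metadata database is assumed:

```python
from datetime import timedelta

from airflow.dag_processing.manager import DagFileProcessorManager
from airflow.utils import timezone

# `session` is omitted: @provide_session injects one when running locally.
DagFileProcessorManager.deactivate_stale_dags(
    last_parsed={"/files/dags/example.py": timezone.utcnow()},
    dag_directory="/files/dags",
    processor_timeout=timedelta(minutes=10),
)
```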
34 changes: 24 additions & 10 deletions tests/api_internal/endpoints/test_rpc_api_endpoint.py
@@ -29,6 +29,7 @@
from tests.test_utils.decorators import dont_initialize_flask_app_submodules

TEST_METHOD_NAME = "test_method"
TEST_METHOD_WITH_LOG_NAME = "test_method_with_log"

mock_test_method = mock.MagicMock()

@@ -58,14 +59,28 @@ def setup_attrs(self, minimal_app_for_internal_api: Flask) -> Generator:
with mock.patch(
"airflow.api_internal.endpoints.rpc_api_endpoint._initialize_map"
) as mock_initialize_map:
mock_initialize_map.return_value = {TEST_METHOD_NAME: mock_test_method}
mock_initialize_map.return_value = {
TEST_METHOD_NAME: mock_test_method,
}
yield mock_initialize_map

@pytest.mark.parametrize(
"input_data, method_result, method_params, expected_code",
"input_data, method_result, method_params, expected_mock, expected_code",
[
({"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": ""}, "test_me", None, 200),
({"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": ""}, None, None, 200),
(
{"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": ""},
"test_me",
{},
mock_test_method,
200,
),
(
{"jsonrpc": "2.0", "method": TEST_METHOD_NAME, "params": ""},
None,
{},
mock_test_method,
200,
),
(
{
"jsonrpc": "2.0",
@@ -74,13 +89,14 @@ def setup_attrs(self, minimal_app_for_internal_api: Flask) -> Generator:
},
("dag_id_15", "fake-task", 1),
{"dag_id": 15, "task_id": "fake-task"},
mock_test_method,
200,
),
],
)
def test_method(self, input_data, method_result, method_params, expected_code):
def test_method(self, input_data, method_result, method_params, expected_mock, expected_code):
if method_result:
mock_test_method.return_value = method_result
expected_mock.return_value = method_result

response = self.client.post(
"/internal_api/v1/rpcapi",
@@ -91,10 +107,8 @@ def test_method(self, input_data, method_result, method_params, expected_code):
if method_result:
response_data = BaseSerialization.deserialize(json.loads(response.data))
assert response_data == method_result
if method_params:
mock_test_method.assert_called_once_with(**method_params)
else:
mock_test_method.assert_called_once()

expected_mock.assert_called_once_with(**method_params)

def test_method_with_exception(self):
mock_test_method.side_effect = ValueError("Error!!!")
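The parametrization in this file now passes {} instead of None for the no-params cases and carries the expected mock explicitly, so every case can end with the same uniform assert_called_once_with(**method_params) check. For reference, a sketch of the request shape these cases post, using the Flask test client the fixture provides; the headers are an assumption, as that part of the test body is not shown in this diff:

```python
import json


def post_rpc(client, method: str, params: str = "") -> None:
    # Mirrors the test's request shape: a JSON-RPC 2.0 envelope posted to
    # the internal API endpoint.
    response = client.post(
        "/internal_api/v1/rpcapi",
        headers={"Content-Type": "application/json"},
        data=json.dumps({"jsonrpc": "2.0", "method": method, "params": params}),
    )
    assert response.status_code == 200
```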
5 changes: 3 additions & 2 deletions tests/api_internal/test_internal_api_call.py
@@ -63,7 +63,7 @@ def fake_method() -> str:

@staticmethod
@internal_api_call
def fake_method_with_params(dag_id: str, task_id: int) -> str:
def fake_method_with_params(dag_id: str, task_id: int, session) -> str:
return f"local-call-with-params-{dag_id}-{task_id}"

@conf_vars(
@@ -124,7 +124,8 @@ def test_remote_call_with_params(self, mock_requests):

mock_requests.post.return_value = response

result = TestInternalApiCall.fake_method_with_params("fake-dag", task_id=123)
result = TestInternalApiCall.fake_method_with_params("fake-dag", task_id=123, session="session")

assert result == "remote-call"
expected_data = json.dumps(
{
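The new session parameter on fake_method_with_params, together with the session="session" argument in the remote-call test, exercises the expectation that a process-local SQLAlchemy session never travels over the wire: the decorator has to drop it before serializing the call. A minimal sketch of that idea — this is an illustration, not Airflow's actual decorator internals, and _post_to_internal_api is a hypothetical stand-in for the real transport:

```python
import functools
import inspect


def _post_to_internal_api(method: str, params: dict) -> str:
    # Hypothetical stand-in for the real RPC transport.
    return f"would POST method={method} params={params}"


def internal_api_call_sketch(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        bound = inspect.signature(func).bind(*args, **kwargs)
        # A DB session is only meaningful in this process; strip it from
        # the payload before the call is shipped to the internal API.
        bound.arguments.pop("session", None)
        return _post_to_internal_api(func.__name__, dict(bound.arguments))
    return wrapper
```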
8 changes: 4 additions & 4 deletions tests/dag_processing/test_manager.py
@@ -495,7 +495,7 @@ def test_recently_modified_file_is_parsed_with_mtime_mode(
> (freezed_base_time - manager.get_last_finish_time("file_1.py")).total_seconds()
)

def test_deactivate_stale_dags(self):
def test_scan_stale_dags(self):
"""
Ensure that DAGs are marked inactive when the file is parsed but the
DagModel.last_parsed_time is not updated.
@@ -545,7 +545,7 @@ def test_deactivate_stale_dags(self):
)
assert serialized_dag_count == 1

manager._deactivate_stale_dags()
manager._scan_stale_dags()

active_dag_count = (
session.query(func.count(DagModel.dag_id))
@@ -567,7 +567,7 @@
("scheduler", "standalone_dag_processor"): "True",
}
)
def test_deactivate_stale_dags_standalone_mode(self):
def test_scan_stale_dags_standalone_mode(self):
"""
Ensure only dags from current dag_directory are updated
"""
@@ -612,7 +612,7 @@ def test_deactivate_stale_dags_standalone_mode(self):
active_dag_count = session.query(func.count(DagModel.dag_id)).filter(DagModel.is_active).scalar()
assert active_dag_count == 2

manager._deactivate_stale_dags()
manager._scan_stale_dags()

active_dag_count = session.query(func.count(DagModel.dag_id)).filter(DagModel.is_active).scalar()
assert active_dag_count == 1
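Both renamed tests ultimately hinge on the staleness condition carried over unchanged into the classmethod: a DAG row is stale once its last_parsed_time trails the file's most recent finish time by more than processor_timeout. A small worked check with illustrative values:

```python
from datetime import datetime, timedelta

processor_timeout = timedelta(minutes=10)
dag_last_parsed_time = datetime(2023, 1, 1, 12, 0)    # DagModel.last_parsed_time
file_last_finish_time = datetime(2023, 1, 1, 12, 15)  # manager's last finish for the file

# 12:00 + 10 min = 12:10, which is before 12:15: the file was re-parsed
# but this DAG did not reappear in it, so it would be deactivated.
assert dag_last_parsed_time + processor_timeout < file_last_finish_time
```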