Skip to content

Commit e46e208

Browse files
dstandishLee-W
authored andcommitted
Max active tasks to be evaluated per dag run (apache#42953)
This behavior change was accepted by lazy consensus here: https://lists.apache.org/thread/9o84d3yn934m32gtlpokpwtbbmtxj47l. Previously max_active_tasks was evaluated across all runs of a dag. Co-authored-by: Wei Lee <[email protected]>
1 parent f8b1f18 commit e46e208

File tree

3 files changed

+144
-137
lines changed

3 files changed

+144
-137
lines changed

airflow/jobs/scheduler_job_runner.py

+28-39
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
import sys
2525
import time
2626
from collections import Counter, defaultdict, deque
27-
from dataclasses import dataclass
2827
from datetime import timedelta
2928
from functools import lru_cache, partial
3029
from pathlib import Path
@@ -83,7 +82,6 @@
8382
from datetime import datetime
8483
from types import FrameType
8584

86-
from sqlalchemy.engine import Result
8785
from sqlalchemy.orm import Query, Session
8886

8987
from airflow.dag_processing.manager import DagFileProcessorAgent
@@ -99,7 +97,6 @@
9997
DM = DagModel
10098

10199

102-
@dataclass
103100
class ConcurrencyMap:
104101
"""
105102
Dataclass to represent concurrency maps.
@@ -109,17 +106,24 @@ class ConcurrencyMap:
109106
to # of task instances in the given state list in each DAG run.
110107
"""
111108

112-
dag_active_tasks_map: dict[str, int]
113-
task_concurrency_map: dict[tuple[str, str], int]
114-
task_dagrun_concurrency_map: dict[tuple[str, str, str], int]
115-
116-
@classmethod
117-
def from_concurrency_map(cls, mapping: dict[tuple[str, str, str], int]) -> ConcurrencyMap:
118-
instance = cls(Counter(), Counter(), Counter(mapping))
119-
for (d, _, t), c in mapping.items():
120-
instance.dag_active_tasks_map[d] += c
121-
instance.task_concurrency_map[(d, t)] += c
122-
return instance
109+
def __init__(self):
110+
self.dag_run_active_tasks_map: Counter[tuple[str, str]] = Counter()
111+
self.task_concurrency_map: Counter[tuple[str, str]] = Counter()
112+
self.task_dagrun_concurrency_map: Counter[tuple[str, str, str]] = Counter()
113+
114+
def load(self, session: Session) -> None:
115+
self.dag_run_active_tasks_map.clear()
116+
self.task_concurrency_map.clear()
117+
self.task_dagrun_concurrency_map.clear()
118+
query = session.execute(
119+
select(TI.dag_id, TI.task_id, TI.run_id, func.count("*"))
120+
.where(TI.state.in_(EXECUTION_STATES))
121+
.group_by(TI.task_id, TI.run_id, TI.dag_id)
122+
)
123+
for dag_id, task_id, run_id, c in query:
124+
self.dag_run_active_tasks_map[dag_id, run_id] += c
125+
self.task_concurrency_map[(dag_id, task_id)] += c
126+
self.task_dagrun_concurrency_map[(dag_id, run_id, task_id)] += c
123127

124128

125129
def _is_parent_process() -> bool:
@@ -258,22 +262,6 @@ def _debug_dump(self, signum: int, frame: FrameType | None) -> None:
258262
executor.debug_dump()
259263
self.log.info("-" * 80)
260264

261-
def __get_concurrency_maps(self, states: Iterable[TaskInstanceState], session: Session) -> ConcurrencyMap:
262-
"""
263-
Get the concurrency maps.
264-
265-
:param states: List of states to query for
266-
:return: Concurrency map
267-
"""
268-
ti_concurrency_query: Result = session.execute(
269-
select(TI.task_id, TI.run_id, TI.dag_id, func.count("*"))
270-
.where(TI.state.in_(states))
271-
.group_by(TI.task_id, TI.run_id, TI.dag_id)
272-
)
273-
return ConcurrencyMap.from_concurrency_map(
274-
{(dag_id, run_id, task_id): count for task_id, run_id, dag_id, count in ti_concurrency_query}
275-
)
276-
277265
def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -> list[TI]:
278266
"""
279267
Find TIs that are ready for execution based on conditions.
@@ -326,7 +314,8 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -
326314
starved_pools = {pool_name for pool_name, stats in pools.items() if stats["open"] <= 0}
327315

328316
# dag_id to # of running tasks and (dag_id, task_id) to # of running tasks.
329-
concurrency_map = self.__get_concurrency_maps(states=EXECUTION_STATES, session=session)
317+
concurrency_map = ConcurrencyMap()
318+
concurrency_map.load(session=session)
330319

331320
# Number of tasks that cannot be scheduled because of no open slot in pool
332321
num_starving_tasks_total = 0
@@ -465,22 +454,22 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -
465454
# Check to make sure that the task max_active_tasks of the DAG hasn't been
466455
# reached.
467456
dag_id = task_instance.dag_id
468-
469-
current_active_tasks_per_dag = concurrency_map.dag_active_tasks_map[dag_id]
470-
max_active_tasks_per_dag_limit = task_instance.dag_model.max_active_tasks
457+
dag_run_key = (dag_id, task_instance.run_id)
458+
current_active_tasks_per_dag_run = concurrency_map.dag_run_active_tasks_map[dag_run_key]
459+
dag_max_active_tasks = task_instance.dag_model.max_active_tasks
471460
self.log.info(
472461
"DAG %s has %s/%s running and queued tasks",
473462
dag_id,
474-
current_active_tasks_per_dag,
475-
max_active_tasks_per_dag_limit,
463+
current_active_tasks_per_dag_run,
464+
dag_max_active_tasks,
476465
)
477-
if current_active_tasks_per_dag >= max_active_tasks_per_dag_limit:
466+
if current_active_tasks_per_dag_run >= dag_max_active_tasks:
478467
self.log.info(
479468
"Not executing %s since the number of tasks running or queued "
480469
"from DAG %s is >= to the DAG's max_active_tasks limit of %s",
481470
task_instance,
482471
dag_id,
483-
max_active_tasks_per_dag_limit,
472+
dag_max_active_tasks,
484473
)
485474
starved_dags.add(dag_id)
486475
continue
@@ -571,7 +560,7 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -
571560

572561
executable_tis.append(task_instance)
573562
open_slots -= task_instance.pool_slots
574-
concurrency_map.dag_active_tasks_map[dag_id] += 1
563+
concurrency_map.dag_run_active_tasks_map[dag_run_key] += 1
575564
concurrency_map.task_concurrency_map[(task_instance.dag_id, task_instance.task_id)] += 1
576565
concurrency_map.task_dagrun_concurrency_map[
577566
(task_instance.dag_id, task_instance.run_id, task_instance.task_id)

newsfragments/42953.significant

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
DAG.max_active_runs now evaluated per-run
2+
3+
Previously, this was evaluated across all runs of the dag. This behavior change was passed by lazy consensus. Vote thread: https://lists.apache.org/thread/9o84d3yn934m32gtlpokpwtbbmtxj47l.

0 commit comments

Comments
 (0)