@@ -24,7 +24,6 @@
import sys
import time
from collections import Counter, defaultdict, deque
-from dataclasses import dataclass
from datetime import timedelta
from functools import lru_cache, partial
from pathlib import Path
@@ -83,7 +82,6 @@
from datetime import datetime
from types import FrameType

-from sqlalchemy.engine import Result
from sqlalchemy.orm import Query, Session

from airflow.dag_processing.manager import DagFileProcessorAgent
@@ -99,7 +97,6 @@
DM = DagModel


-@dataclass
class ConcurrencyMap:
    """
    Dataclass to represent concurrency maps.
@@ -109,17 +106,24 @@ class ConcurrencyMap:
    to # of task instances in the given state list in each DAG run.
    """

-    dag_active_tasks_map: dict[str, int]
-    task_concurrency_map: dict[tuple[str, str], int]
-    task_dagrun_concurrency_map: dict[tuple[str, str, str], int]
-
-    @classmethod
-    def from_concurrency_map(cls, mapping: dict[tuple[str, str, str], int]) -> ConcurrencyMap:
-        instance = cls(Counter(), Counter(), Counter(mapping))
-        for (d, _, t), c in mapping.items():
-            instance.dag_active_tasks_map[d] += c
-            instance.task_concurrency_map[(d, t)] += c
-        return instance
+    def __init__(self):
+        self.dag_run_active_tasks_map: Counter[tuple[str, str]] = Counter()
+        self.task_concurrency_map: Counter[tuple[str, str]] = Counter()
+        self.task_dagrun_concurrency_map: Counter[tuple[str, str, str]] = Counter()
+
+    def load(self, session: Session) -> None:
+        self.dag_run_active_tasks_map.clear()
+        self.task_concurrency_map.clear()
+        self.task_dagrun_concurrency_map.clear()
+        query = session.execute(
+            select(TI.dag_id, TI.task_id, TI.run_id, func.count("*"))
+            .where(TI.state.in_(EXECUTION_STATES))
+            .group_by(TI.task_id, TI.run_id, TI.dag_id)
+        )
+        for dag_id, task_id, run_id, c in query:
+            self.dag_run_active_tasks_map[dag_id, run_id] += c
+            self.task_concurrency_map[(dag_id, task_id)] += c
+            self.task_dagrun_concurrency_map[(dag_id, run_id, task_id)] += c


def _is_parent_process() -> bool:
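
For illustration only, not part of the commit: a minimal sketch of how the reworked ConcurrencyMap above is consumed. It assumes an open SQLAlchemy session and uses hypothetical dag, run, and task IDs; because the maps are collections.Counter instances, lookups for keys that were never counted simply return 0.

# Usage sketch (assumptions: `session` is an active SQLAlchemy Session; all IDs are hypothetical).
concurrency_map = ConcurrencyMap()
concurrency_map.load(session=session)

# Running/queued task instances in one DAG run, keyed by (dag_id, run_id).
active_in_run = concurrency_map.dag_run_active_tasks_map[("example_dag", "manual__2024-01-01")]

# Running/queued instances of one task across all runs, keyed by (dag_id, task_id).
active_for_task = concurrency_map.task_concurrency_map[("example_dag", "example_task")]

# Counter returns 0 for unseen keys, so no membership checks are needed.
assert concurrency_map.dag_run_active_tasks_map[("missing_dag", "missing_run")] == 0
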
@@ -258,22 +262,6 @@ def _debug_dump(self, signum: int, frame: FrameType | None) -> None:
            executor.debug_dump()
            self.log.info("-" * 80)

-    def __get_concurrency_maps(self, states: Iterable[TaskInstanceState], session: Session) -> ConcurrencyMap:
-        """
-        Get the concurrency maps.
-
-        :param states: List of states to query for
-        :return: Concurrency map
-        """
-        ti_concurrency_query: Result = session.execute(
-            select(TI.task_id, TI.run_id, TI.dag_id, func.count("*"))
-            .where(TI.state.in_(states))
-            .group_by(TI.task_id, TI.run_id, TI.dag_id)
-        )
-        return ConcurrencyMap.from_concurrency_map(
-            {(dag_id, run_id, task_id): count for task_id, run_id, dag_id, count in ti_concurrency_query}
-        )
-
    def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -> list[TI]:
        """
        Find TIs that are ready for execution based on conditions.
@@ -326,7 +314,8 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -> list[TI]:
        starved_pools = {pool_name for pool_name, stats in pools.items() if stats["open"] <= 0}

        # dag_id to # of running tasks and (dag_id, task_id) to # of running tasks.
-        concurrency_map = self.__get_concurrency_maps(states=EXECUTION_STATES, session=session)
+        concurrency_map = ConcurrencyMap()
+        concurrency_map.load(session=session)

        # Number of tasks that cannot be scheduled because of no open slot in pool
        num_starving_tasks_total = 0
@@ -465,22 +454,22 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -> list[TI]:
                # Check to make sure that the task max_active_tasks of the DAG hasn't been
                # reached.
                dag_id = task_instance.dag_id
-
-                current_active_tasks_per_dag = concurrency_map.dag_active_tasks_map[dag_id]
-                max_active_tasks_per_dag_limit = task_instance.dag_model.max_active_tasks
+                dag_run_key = (dag_id, task_instance.run_id)
+                current_active_tasks_per_dag_run = concurrency_map.dag_run_active_tasks_map[dag_run_key]
+                dag_max_active_tasks = task_instance.dag_model.max_active_tasks
                self.log.info(
                    "DAG %s has %s/%s running and queued tasks",
                    dag_id,
-                    current_active_tasks_per_dag,
-                    max_active_tasks_per_dag_limit,
+                    current_active_tasks_per_dag_run,
+                    dag_max_active_tasks,
                )
-                if current_active_tasks_per_dag >= max_active_tasks_per_dag_limit:
+                if current_active_tasks_per_dag_run >= dag_max_active_tasks:
                    self.log.info(
                        "Not executing %s since the number of tasks running or queued "
                        "from DAG %s is >= to the DAG's max_active_tasks limit of %s",
                        task_instance,
                        dag_id,
-                        max_active_tasks_per_dag_limit,
+                        dag_max_active_tasks,
                    )
                    starved_dags.add(dag_id)
                    continue
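
For illustration only, not part of the commit: a self-contained sketch of what the re-keyed check above changes, using hypothetical DAG and run IDs. Previously active tasks were counted per dag_id, so concurrent runs of one DAG shared a single max_active_tasks budget; counting per (dag_id, run_id) compares each run against the limit on its own.

from collections import Counter

# Hypothetical numbers: two runs of the same DAG, one running/queued task each,
# and max_active_tasks = 2 configured on the DAG.
dag_run_active_tasks_map: Counter[tuple[str, str]] = Counter(
    {("example_dag", "run_1"): 1, ("example_dag", "run_2"): 1}
)
dag_max_active_tasks = 2

# Old keying (dag_id only): 1 + 1 = 2 active tasks hit the limit, starving both runs.
total_for_dag = sum(
    count for (dag_id, _), count in dag_run_active_tasks_map.items() if dag_id == "example_dag"
)
print(total_for_dag >= dag_max_active_tasks)  # True

# New keying (dag_id, run_id): each run is checked independently and can still queue tasks.
for run_key in (("example_dag", "run_1"), ("example_dag", "run_2")):
    print(run_key, dag_run_active_tasks_map[run_key] >= dag_max_active_tasks)  # False for both
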
@@ -571,7 +560,7 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session) -

                executable_tis.append(task_instance)
                open_slots -= task_instance.pool_slots
-                concurrency_map.dag_active_tasks_map[dag_id] += 1
+                concurrency_map.dag_run_active_tasks_map[dag_run_key] += 1
                concurrency_map.task_concurrency_map[(task_instance.dag_id, task_instance.task_id)] += 1
                concurrency_map.task_dagrun_concurrency_map[
                    (task_instance.dag_id, task_instance.run_id, task_instance.task_id)