26
26
27
27
from airflow .models .dagrun import DagRun
28
28
from airflow .models .taskinstance import TaskInstance
29
- from airflow .operators .subdag import SubDagOperator
30
29
from airflow .utils import timezone
31
30
from airflow .utils .helpers import exactly_one
32
31
from airflow .utils .session import NEW_SESSION , provide_session
33
32
from airflow .utils .state import DagRunState , State , TaskInstanceState
34
- from airflow .utils .types import DagRunType
35
33
36
34
if TYPE_CHECKING :
37
35
from datetime import datetime
40
38
41
39
from airflow .models .dag import DAG
42
40
from airflow .models .operator import Operator
41
+ from airflow .utils .types import DagRunType
43
42
44
43
45
44
class _DagRunInfo (NamedTuple ):
@@ -101,14 +100,14 @@ def set_state(
101
100
Can set state for future tasks (calculated from run_id) and retroactively
102
101
for past tasks. Will verify integrity of past dag runs in order to create
103
102
tasks that did not exist. It will not create dag runs that are missing
104
- on the schedule (but it will, as for subdag, dag runs if needed) .
103
+ on the schedule.
105
104
106
105
:param tasks: the iterable of tasks or (task, map_index) tuples from which to work.
107
106
``task.dag`` needs to be set
108
107
:param run_id: the run_id of the dagrun to start looking from
109
108
:param execution_date: the execution date from which to start looking (deprecated)
110
109
:param upstream: Mark all parents (upstream tasks)
111
- :param downstream: Mark all siblings (downstream tasks) of task_id, including SubDags
110
+ :param downstream: Mark all siblings (downstream tasks) of task_id
112
111
:param future: Mark all future tasks on the interval of the dag up until
113
112
last execution date.
114
113
:param past: Retroactively mark all tasks starting from start_date of the DAG
@@ -140,54 +139,20 @@ def set_state(
140
139
141
140
dag_run_ids = get_run_ids (dag , run_id , future , past , session = session )
142
141
task_id_map_index_list = list (find_task_relatives (tasks , downstream , upstream ))
143
- task_ids = [task_id if isinstance (task_id , str ) else task_id [0 ] for task_id in task_id_map_index_list ]
144
-
145
- confirmed_infos = list (_iter_existing_dag_run_infos (dag , dag_run_ids , session = session ))
146
- confirmed_dates = [info .logical_date for info in confirmed_infos ]
147
-
148
- sub_dag_run_ids = (
149
- list (
150
- _iter_subdag_run_ids (dag , session , DagRunState (state ), task_ids , commit , confirmed_infos ),
151
- )
152
- if not state == TaskInstanceState .SKIPPED
153
- else []
154
- )
155
-
156
142
# now look for the task instances that are affected
157
143
158
144
qry_dag = get_all_dag_task_query (dag , session , state , task_id_map_index_list , dag_run_ids )
159
145
160
146
if commit :
161
147
tis_altered = session .scalars (qry_dag .with_for_update ()).all ()
162
- if sub_dag_run_ids :
163
- qry_sub_dag = all_subdag_tasks_query (sub_dag_run_ids , session , state , confirmed_dates )
164
- tis_altered += session .scalars (qry_sub_dag .with_for_update ()).all ()
165
148
for task_instance in tis_altered :
166
149
task_instance .set_state (state , session = session )
167
150
session .flush ()
168
151
else :
169
152
tis_altered = session .scalars (qry_dag ).all ()
170
- if sub_dag_run_ids :
171
- qry_sub_dag = all_subdag_tasks_query (sub_dag_run_ids , session , state , confirmed_dates )
172
- tis_altered += session .scalars (qry_sub_dag ).all ()
173
153
return tis_altered
174
154
175
155
176
- def all_subdag_tasks_query (
177
- sub_dag_run_ids : list [str ],
178
- session : SASession ,
179
- state : TaskInstanceState ,
180
- confirmed_dates : Iterable [datetime ],
181
- ):
182
- """Get *all* tasks of the sub dags."""
183
- qry_sub_dag = (
184
- select (TaskInstance )
185
- .where (TaskInstance .dag_id .in_ (sub_dag_run_ids ), TaskInstance .execution_date .in_ (confirmed_dates ))
186
- .where (or_ (TaskInstance .state .is_ (None ), TaskInstance .state != state ))
187
- )
188
- return qry_sub_dag
189
-
190
-
191
156
def get_all_dag_task_query (
192
157
dag : DAG ,
193
158
session : SASession ,
@@ -208,71 +173,6 @@ def get_all_dag_task_query(
208
173
return qry_dag
209
174
210
175
211
- def _iter_subdag_run_ids (
212
- dag : DAG ,
213
- session : SASession ,
214
- state : DagRunState ,
215
- task_ids : list [str ],
216
- commit : bool ,
217
- confirmed_infos : Iterable [_DagRunInfo ],
218
- ) -> Iterator [str ]:
219
- """
220
- Go through subdag operators and create dag runs.
221
-
222
- We only work within the scope of the subdag. A subdag does not propagate to
223
- its parent DAG, but parent propagates to subdags.
224
- """
225
- dags = [dag ]
226
- while dags :
227
- current_dag = dags .pop ()
228
- for task_id in task_ids :
229
- if not current_dag .has_task (task_id ):
230
- continue
231
-
232
- current_task = current_dag .get_task (task_id )
233
- if isinstance (current_task , SubDagOperator ) or current_task .task_type == "SubDagOperator" :
234
- # this works as a kind of integrity check
235
- # it creates missing dag runs for subdag operators,
236
- # maybe this should be moved to dagrun.verify_integrity
237
- if TYPE_CHECKING :
238
- assert current_task .subdag
239
- dag_runs = _create_dagruns (
240
- current_task .subdag ,
241
- infos = confirmed_infos ,
242
- state = DagRunState .RUNNING ,
243
- run_type = DagRunType .BACKFILL_JOB ,
244
- )
245
-
246
- verify_dagruns (dag_runs , commit , state , session , current_task )
247
-
248
- dags .append (current_task .subdag )
249
- yield current_task .subdag .dag_id
250
-
251
-
252
- def verify_dagruns (
253
- dag_runs : Iterable [DagRun ],
254
- commit : bool ,
255
- state : DagRunState ,
256
- session : SASession ,
257
- current_task : Operator ,
258
- ):
259
- """
260
- Verify integrity of dag_runs.
261
-
262
- :param dag_runs: dag runs to verify
263
- :param commit: whether dag runs state should be updated
264
- :param state: state of the dag_run to set if commit is True
265
- :param session: session to use
266
- :param current_task: current task
267
- """
268
- for dag_run in dag_runs :
269
- dag_run .dag = current_task .subdag
270
- dag_run .verify_integrity ()
271
- if commit :
272
- dag_run .state = state
273
- session .merge (dag_run )
274
-
275
-
276
176
def _iter_existing_dag_run_infos (dag : DAG , run_ids : list [str ], session : SASession ) -> Iterator [_DagRunInfo ]:
277
177
for dag_run in DagRun .find (dag_id = dag .dag_id , run_id = run_ids , session = session ):
278
178
dag_run .dag = dag
0 commit comments