@security.requires_access(
    [
        (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG),
        (permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG_RUN),
    ],
)
@provide_session
def clear_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse:
    """
    Clear a DAG run, or preview which task instances clearing would touch.

    The request body is validated with ``clear_dagrun_form_schema``. Only the
    task instances whose logical date equals the run's logical date are
    targeted (the run's logical date is used as both start and end date).

    :param dag_id: ID of the DAG the run belongs to.
    :param dag_run_id: Run ID of the DAG run to clear.
    :param session: Database session used to look up the DAG run.
    :return: When ``dry_run`` is true, the serialized collection of task
        instance references returned by ``dag.clear(dry_run=True)``; nothing
        is cleared. Otherwise the serialized DAG run, refreshed from the
        database after clearing.
    :raises NotFound: If the DAG run does not exist, or the DAG itself cannot
        be loaded from the DagBag.
    :raises BadRequest: If the request body fails schema validation.
    """
    dag_run: Optional[DagRun] = (
        session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none()
    )
    if dag_run is None:
        error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}'
        raise NotFound(error_message)

    try:
        post_body = clear_dagrun_form_schema.load(request.json)
    except ValidationError as err:
        # Chain the original error so the validation traceback is preserved.
        raise BadRequest(detail=str(err)) from err

    # The form schema and the OpenAPI spec both default ``dry_run`` to True;
    # mirror that here so an absent key can never trigger a destructive clear.
    dry_run = post_body.get('dry_run', True)

    dag = current_app.dag_bag.get_dag(dag_id)
    if not dag:
        # A run can exist in the database while its DAG is missing from the
        # DagBag; report that as 404 instead of failing on ``None.clear``.
        raise NotFound(f'Dag id {dag_id} not found')

    # Restrict the clear to this single run by using its logical date as both
    # bounds of the date window.
    start_date = dag_run.logical_date
    end_date = dag_run.logical_date

    if dry_run:
        task_instances = dag.clear(
            start_date=start_date,
            end_date=end_date,
            task_ids=None,
            include_subdags=True,
            include_parentdag=True,
            only_failed=False,
            dry_run=True,
        )
        return task_instance_reference_collection_schema.dump(
            TaskInstanceReferenceCollection(task_instances=task_instances)
        )

    dag.clear(
        start_date=start_date,
        end_date=end_date,
        task_ids=None,
        include_subdags=True,
        include_parentdag=True,
        only_failed=False,
    )
    # Re-read the run so the response reflects its state after clearing.
    dag_run.refresh_from_db()
    return dagrun_schema.dump(dag_run)
class TestClearDagRun(TestDagRunEndpoint):
    """Tests for ``POST api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/clear``."""

    def _create_dag_run_with_successful_ti(self, dag_maker, session, dag_id, dag_run_id):
        """
        Build a one-task DAG, register it in the app's DagBag, create a run
        for it and persist its single task instance in the SUCCESS state.

        :return: ``(dag_run, task_instance)`` tuple.
        """
        with dag_maker(dag_id) as dag:
            task = EmptyOperator(task_id="task_id", dag=dag)
        self.app.dag_bag.bag_dag(dag, root_dag=dag)
        dr = dag_maker.create_dagrun(run_id=dag_run_id)
        ti = dr.get_task_instance(task_id="task_id")
        ti.task = task
        ti.state = State.SUCCESS
        session.merge(ti)
        session.commit()
        return dr, ti

    def test_should_respond_200(self, dag_maker, session):
        """With dry_run False the run is returned queued and the task instance state is reset."""
        dag_id = "TEST_DAG_ID"
        dag_run_id = "TEST_DAG_RUN_ID"
        dr, ti = self._create_dag_run_with_successful_ti(dag_maker, session, dag_id, dag_run_id)

        request_json = {"dry_run": False}

        response = self.client.post(
            f"api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/clear",
            json=request_json,
            environ_overrides={"REMOTE_USER": "test"},
        )

        # Reload the run so the expected payload is built from persisted values.
        dr = session.query(DagRun).filter(DagRun.run_id == dr.run_id).first()
        assert response.status_code == 200
        assert response.json == {
            "conf": {},
            "dag_id": dag_id,
            "dag_run_id": dag_run_id,
            "end_date": None,
            "execution_date": dr.execution_date.isoformat(),
            "external_trigger": False,
            "logical_date": dr.logical_date.isoformat(),
            "start_date": dr.logical_date.isoformat(),
            "state": "queued",
            "data_interval_start": dr.data_interval_start.isoformat(),
            "data_interval_end": dr.data_interval_end.isoformat(),
            "last_scheduling_decision": None,
            "run_type": dr.run_type,
        }

        ti.refresh_from_db()
        assert ti.state is None

    def test_dry_run(self, dag_maker, session):
        """Test that dry_run being True returns TaskInstances without clearing DagRun"""
        dag_id = "TEST_DAG_ID"
        dag_run_id = "TEST_DAG_RUN_ID"
        dr, ti = self._create_dag_run_with_successful_ti(dag_maker, session, dag_id, dag_run_id)

        request_json = {"dry_run": True}

        response = self.client.post(
            f"api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/clear",
            json=request_json,
            environ_overrides={"REMOTE_USER": "test"},
        )

        assert response.status_code == 200
        assert response.json == {
            "task_instances": [
                {
                    "dag_id": dag_id,
                    "dag_run_id": dag_run_id,
                    "execution_date": dr.execution_date.isoformat(),
                    "task_id": "task_id",
                }
            ]
        }

        # Neither the task instance nor the run may have been modified.
        ti.refresh_from_db()
        assert ti.state == State.SUCCESS

        dr = session.query(DagRun).filter(DagRun.run_id == dr.run_id).first()
        assert dr.state == "running"

    def test_should_raises_401_unauthenticated(self, session):
        """A request without a REMOTE_USER is rejected as unauthenticated."""
        response = self.client.post(
            "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1/clear",
            json={
                "dry_run": True,
            },
        )

        assert_401(response)

    def test_should_raise_403_forbidden(self):
        """A user without the required permissions gets 403."""
        response = self.client.post(
            "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1/clear",
            json={
                "dry_run": True,
            },
            environ_overrides={"REMOTE_USER": "test_no_permissions"},
        )
        assert response.status_code == 403

    def test_should_respond_404(self):
        """Clearing a run of an unknown DAG returns 404."""
        response = self.client.post(
            "api/v1/dags/INVALID_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1/clear",
            json={
                "dry_run": True,
            },
            environ_overrides={"REMOTE_USER": "test"},
        )
        assert response.status_code == 404