diff --git a/airflow/providers/amazon/aws/operators/glue.py b/airflow/providers/amazon/aws/operators/glue.py
index 1d6146e42b9a2..060ac358a40c2 100644
--- a/airflow/providers/amazon/aws/operators/glue.py
+++ b/airflow/providers/amazon/aws/operators/glue.py
@@ -19,6 +19,7 @@
 
 import os.path
 import urllib.parse
+from functools import cached_property
 from typing import TYPE_CHECKING, Sequence
 
 from airflow import AirflowException
@@ -60,6 +61,7 @@ class GlueJobOperator(BaseOperator):
         (default: False)
     :param verbose: If True, Glue Job Run logs show in the Airflow Task Logs. (default: False)
     :param update_config: If True, Operator will update job configuration. (default: False)
+    :param stop_job_run_on_kill: If True, Operator will stop the job run when task is killed.
     """
 
     template_fields: Sequence[str] = (
@@ -100,6 +102,7 @@ def __init__(
         verbose: bool = False,
         update_config: bool = False,
         job_poll_interval: int | float = 6,
+        stop_job_run_on_kill: bool = False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -123,12 +126,11 @@ def __init__(
         self.update_config = update_config
         self.deferrable = deferrable
         self.job_poll_interval = job_poll_interval
+        self.stop_job_run_on_kill = stop_job_run_on_kill
+        self._job_run_id: str | None = None
 
-    def execute(self, context: Context):
-        """Execute AWS Glue Job from Airflow.
-
-        :return: the current Glue job ID.
-        """
+    @cached_property
+    def glue_job_hook(self) -> GlueJobHook:
         if self.script_location is None:
             s3_script_location = None
         elif not self.script_location.startswith(self.s3_protocol):
@@ -140,7 +142,7 @@ def execute(self, context: Context):
             s3_script_location = f"s3://{self.s3_bucket}/{self.s3_artifacts_prefix}{script_name}"
         else:
             s3_script_location = self.script_location
-        glue_job = GlueJobHook(
+        return GlueJobHook(
             job_name=self.job_name,
             desc=self.job_desc,
             concurrent_run_limit=self.concurrent_run_limit,
@@ -155,25 +157,32 @@ def execute(self, context: Context):
             update_config=self.update_config,
             job_poll_interval=self.job_poll_interval,
         )
+
+    def execute(self, context: Context):
+        """Execute AWS Glue Job from Airflow.
+
+        :return: the current Glue job ID.
+        """
         self.log.info(
             "Initializing AWS Glue Job: %s. Wait for completion: %s",
             self.job_name,
             self.wait_for_completion,
         )
-        glue_job_run = glue_job.initialize_job(self.script_args, self.run_job_kwargs)
+        glue_job_run = self.glue_job_hook.initialize_job(self.script_args, self.run_job_kwargs)
+        self._job_run_id = glue_job_run["JobRunId"]
         glue_job_run_url = GlueJobRunDetailsLink.format_str.format(
-            aws_domain=GlueJobRunDetailsLink.get_aws_domain(glue_job.conn_partition),
-            region_name=glue_job.conn_region_name,
+            aws_domain=GlueJobRunDetailsLink.get_aws_domain(self.glue_job_hook.conn_partition),
+            region_name=self.glue_job_hook.conn_region_name,
             job_name=urllib.parse.quote(self.job_name, safe=""),
-            job_run_id=glue_job_run["JobRunId"],
+            job_run_id=self._job_run_id,
         )
         GlueJobRunDetailsLink.persist(
             context=context,
             operator=self,
-            region_name=glue_job.conn_region_name,
-            aws_partition=glue_job.conn_partition,
+            region_name=self.glue_job_hook.conn_region_name,
+            aws_partition=self.glue_job_hook.conn_partition,
             job_name=urllib.parse.quote(self.job_name, safe=""),
-            job_run_id=glue_job_run["JobRunId"],
+            job_run_id=self._job_run_id,
         )
         self.log.info("You can monitor this Glue Job run at: %s", glue_job_run_url)
@@ -181,7 +190,7 @@ def execute(self, context: Context):
             self.defer(
                 trigger=GlueJobCompleteTrigger(
                     job_name=self.job_name,
-                    run_id=glue_job_run["JobRunId"],
+                    run_id=self._job_run_id,
                     verbose=self.verbose,
                     aws_conn_id=self.aws_conn_id,
                     job_poll_interval=self.job_poll_interval,
@@ -189,18 +198,29 @@ def execute(self, context: Context):
                 method_name="execute_complete",
             )
         elif self.wait_for_completion:
-            glue_job_run = glue_job.job_completion(self.job_name, glue_job_run["JobRunId"], self.verbose)
+            glue_job_run = self.glue_job_hook.job_completion(self.job_name, self._job_run_id, self.verbose)
             self.log.info(
                 "AWS Glue Job: %s status: %s. Run Id: %s",
                 self.job_name,
                 glue_job_run["JobRunState"],
-                glue_job_run["JobRunId"],
+                self._job_run_id,
             )
         else:
-            self.log.info("AWS Glue Job: %s. Run Id: %s", self.job_name, glue_job_run["JobRunId"])
-        return glue_job_run["JobRunId"]
+            self.log.info("AWS Glue Job: %s. Run Id: %s", self.job_name, self._job_run_id)
+        return self._job_run_id
 
     def execute_complete(self, context, event=None):
         if event["status"] != "success":
             raise AirflowException(f"Error in glue job: {event}")
         return event["value"]
+
+    def on_kill(self):
+        """Cancel the running AWS Glue Job."""
+        if self.stop_job_run_on_kill:
+            self.log.info("Stopping AWS Glue Job: %s. Run Id: %s", self.job_name, self._job_run_id)
+            response = self.glue_job_hook.conn.batch_stop_job_run(
+                JobName=self.job_name,
+                JobRunIds=[self._job_run_id],
+            )
+            if not response["SuccessfulSubmissions"]:
+                self.log.error("Failed to stop AWS Glue Job: %s. Run Id: %s", self.job_name, self._job_run_id)
diff --git a/tests/providers/amazon/aws/operators/test_glue.py b/tests/providers/amazon/aws/operators/test_glue.py
index 03b5e154f47e4..9eed48e47adc9 100644
--- a/tests/providers/amazon/aws/operators/test_glue.py
+++ b/tests/providers/amazon/aws/operators/test_glue.py
@@ -207,3 +207,46 @@ def test_log_correct_url(
 
         assert job_run_id == JOB_RUN_ID
         mock_log_info.assert_any_call("You can monitor this Glue Job run at: %s", glue_job_run_url)
+
+    @mock.patch.object(GlueJobHook, "conn")
+    @mock.patch.object(GlueJobHook, "get_conn")
+    def test_killed_without_stop_job_run_on_kill(
+        self,
+        _,
+        mock_glue_hook,
+    ):
+        glue = GlueJobOperator(
+            task_id=TASK_ID,
+            job_name=JOB_NAME,
+            script_location="s3://folder/file",
+            aws_conn_id="aws_default",
+            region_name="us-west-2",
+            s3_bucket="some_bucket",
+            iam_role_name="my_test_role",
+        )
+        glue.on_kill()
+        mock_glue_hook.batch_stop_job_run.assert_not_called()
+
+    @mock.patch.object(GlueJobHook, "conn")
+    @mock.patch.object(GlueJobHook, "get_conn")
+    def test_killed_with_stop_job_run_on_kill(
+        self,
+        _,
+        mock_glue_hook,
+    ):
+        glue = GlueJobOperator(
+            task_id=TASK_ID,
+            job_name=JOB_NAME,
+            script_location="s3://folder/file",
+            aws_conn_id="aws_default",
+            region_name="us-west-2",
+            s3_bucket="some_bucket",
+            iam_role_name="my_test_role",
+            stop_job_run_on_kill=True,
+        )
+        glue._job_run_id = JOB_RUN_ID
+        glue.on_kill()
+        mock_glue_hook.batch_stop_job_run.assert_called_once_with(
+            JobName=JOB_NAME,
+            JobRunIds=[JOB_RUN_ID],
+        )
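
For context, a minimal usage sketch of the new flag (not part of the change; the DAG id, job name, bucket, script path, and role below are placeholder values). With stop_job_run_on_kill=True, killing the task triggers on_kill(), which calls batch_stop_job_run for the run id recorded during execute(); the default (False) preserves the old behavior of leaving the Glue run running.

    # Usage sketch with placeholder values, assuming an Airflow version that
    # supports the DAG "schedule" argument.
    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.glue import GlueJobOperator

    with DAG(
        dag_id="example_glue_stop_on_kill",  # placeholder DAG id
        start_date=datetime(2023, 1, 1),
        schedule=None,
        catchup=False,
    ):
        submit_glue_job = GlueJobOperator(
            task_id="submit_glue_job",
            job_name="my-glue-job",  # placeholder job name
            script_location="s3://my-bucket/scripts/etl.py",  # placeholder script
            s3_bucket="my-bucket",  # placeholder bucket
            iam_role_name="my-glue-role",  # placeholder role
            region_name="us-west-2",
            wait_for_completion=True,
            # New flag from this change: on task kill, on_kill() issues
            # batch_stop_job_run for the tracked run id.
            stop_job_run_on_kill=True,
        )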