 
 import os.path
 import urllib.parse
+from functools import cached_property
 from typing import TYPE_CHECKING, Sequence
 
 from airflow import AirflowException
@@ -60,6 +61,7 @@ class GlueJobOperator(BaseOperator):
         (default: False)
     :param verbose: If True, Glue Job Run logs show in the Airflow Task Logs. (default: False)
     :param update_config: If True, Operator will update job configuration. (default: False)
+    :param stop_job_run_on_kill: If True, Operator will stop the job run when the task is killed. (default: False)
     """
 
     template_fields: Sequence[str] = (
@@ -100,6 +102,7 @@ def __init__(
         verbose: bool = False,
         update_config: bool = False,
         job_poll_interval: int | float = 6,
+        stop_job_run_on_kill: bool = False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -123,12 +126,11 @@ def __init__(
         self.update_config = update_config
         self.deferrable = deferrable
         self.job_poll_interval = job_poll_interval
+        self.stop_job_run_on_kill = stop_job_run_on_kill
+        self._job_run_id: str | None = None
 
-    def execute(self, context: Context):
-        """Execute AWS Glue Job from Airflow.
-
-        :return: the current Glue job ID.
-        """
+    @cached_property
+    def glue_job_hook(self) -> GlueJobHook:
         if self.script_location is None:
             s3_script_location = None
         elif not self.script_location.startswith(self.s3_protocol):
@@ -140,7 +142,7 @@ def execute(self, context: Context):
             s3_script_location = f"s3://{self.s3_bucket}/{self.s3_artifacts_prefix}{script_name}"
         else:
             s3_script_location = self.script_location
-        glue_job = GlueJobHook(
+        return GlueJobHook(
             job_name=self.job_name,
             desc=self.job_desc,
             concurrent_run_limit=self.concurrent_run_limit,
@@ -155,52 +157,70 @@ def execute(self, context: Context):
             update_config=self.update_config,
             job_poll_interval=self.job_poll_interval,
         )
+
+    def execute(self, context: Context):
+        """Execute AWS Glue Job from Airflow.
+
+        :return: the current Glue job ID.
+        """
         self.log.info(
             "Initializing AWS Glue Job: %s. Wait for completion: %s",
             self.job_name,
             self.wait_for_completion,
         )
-        glue_job_run = glue_job.initialize_job(self.script_args, self.run_job_kwargs)
+        glue_job_run = self.glue_job_hook.initialize_job(self.script_args, self.run_job_kwargs)
+        self._job_run_id = glue_job_run["JobRunId"]
         glue_job_run_url = GlueJobRunDetailsLink.format_str.format(
-            aws_domain=GlueJobRunDetailsLink.get_aws_domain(glue_job.conn_partition),
-            region_name=glue_job.conn_region_name,
+            aws_domain=GlueJobRunDetailsLink.get_aws_domain(self.glue_job_hook.conn_partition),
+            region_name=self.glue_job_hook.conn_region_name,
             job_name=urllib.parse.quote(self.job_name, safe=""),
-            job_run_id=glue_job_run["JobRunId"],
+            job_run_id=self._job_run_id,
         )
         GlueJobRunDetailsLink.persist(
             context=context,
             operator=self,
-            region_name=glue_job.conn_region_name,
-            aws_partition=glue_job.conn_partition,
+            region_name=self.glue_job_hook.conn_region_name,
+            aws_partition=self.glue_job_hook.conn_partition,
             job_name=urllib.parse.quote(self.job_name, safe=""),
-            job_run_id=glue_job_run["JobRunId"],
+            job_run_id=self._job_run_id,
         )
         self.log.info("You can monitor this Glue Job run at: %s", glue_job_run_url)
 
         if self.deferrable:
             self.defer(
                 trigger=GlueJobCompleteTrigger(
                     job_name=self.job_name,
-                    run_id=glue_job_run["JobRunId"],
+                    run_id=self._job_run_id,
                     verbose=self.verbose,
                     aws_conn_id=self.aws_conn_id,
                     job_poll_interval=self.job_poll_interval,
                 ),
                 method_name="execute_complete",
             )
         elif self.wait_for_completion:
-            glue_job_run = glue_job.job_completion(self.job_name, glue_job_run["JobRunId"], self.verbose)
+            glue_job_run = self.glue_job_hook.job_completion(self.job_name, self._job_run_id, self.verbose)
             self.log.info(
                 "AWS Glue Job: %s status: %s. Run Id: %s",
                 self.job_name,
                 glue_job_run["JobRunState"],
-                glue_job_run["JobRunId"],
+                self._job_run_id,
             )
         else:
-            self.log.info("AWS Glue Job: %s. Run Id: %s", self.job_name, glue_job_run["JobRunId"])
-        return glue_job_run["JobRunId"]
+            self.log.info("AWS Glue Job: %s. Run Id: %s", self.job_name, self._job_run_id)
+        return self._job_run_id
 
     def execute_complete(self, context, event=None):
         if event["status"] != "success":
             raise AirflowException(f"Error in glue job: {event}")
         return event["value"]
+
+    def on_kill(self):
+        """Cancel the running AWS Glue Job."""
+        if self.stop_job_run_on_kill:
+            self.log.info("Stopping AWS Glue Job: %s. Run Id: %s", self.job_name, self._job_run_id)
+            response = self.glue_job_hook.conn.batch_stop_job_run(
+                JobName=self.job_name,
+                JobRunIds=[self._job_run_id],
+            )
+            if not response["SuccessfulSubmissions"]:
+                self.log.error("Failed to stop AWS Glue Job: %s. Run Id: %s", self.job_name, self._job_run_id)
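
A minimal usage sketch of the new flag, for context. The DAG id, Glue job name, script path, and IAM role below are illustrative placeholders, not part of this change.

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.glue import GlueJobOperator

with DAG(dag_id="example_glue_stop_on_kill", start_date=datetime(2023, 1, 1), schedule=None) as dag:
    submit_glue_job = GlueJobOperator(
        task_id="submit_glue_job",
        job_name="example-glue-job",                    # placeholder Glue job name
        script_location="s3://example-bucket/etl.py",   # placeholder script path
        iam_role_name="example-glue-role",              # placeholder IAM role
        wait_for_completion=True,
        # New in this change: when the Airflow task is killed, on_kill() calls
        # batch_stop_job_run for the tracked run id. Defaults to False.
        stop_job_run_on_kill=True,
    )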