Skip to content

Commit 1d60332

Browse files
Add an option to GlueJobOperator to stop the job run when the TI is killed (#32155)
--------- Signed-off-by: Hussein Awala <[email protected]>
1 parent 98c47f4 commit 1d60332

File tree

2 files changed

+81
-18
lines changed

2 files changed

+81
-18
lines changed

airflow/providers/amazon/aws/operators/glue.py

Lines changed: 38 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,7 @@
1919

2020
import os.path
2121
import urllib.parse
22+
from functools import cached_property
2223
from typing import TYPE_CHECKING, Sequence
2324

2425
from airflow import AirflowException
@@ -60,6 +61,7 @@ class GlueJobOperator(BaseOperator):
6061
(default: False)
6162
:param verbose: If True, Glue Job Run logs show in the Airflow Task Logs. (default: False)
6263
:param update_config: If True, Operator will update job configuration. (default: False)
64+
:param stop_job_run_on_kill: If True, Operator will stop the job run when task is killed.
6365
"""
6466

6567
template_fields: Sequence[str] = (
@@ -100,6 +102,7 @@ def __init__(
100102
verbose: bool = False,
101103
update_config: bool = False,
102104
job_poll_interval: int | float = 6,
105+
stop_job_run_on_kill: bool = False,
103106
**kwargs,
104107
):
105108
super().__init__(**kwargs)
@@ -123,12 +126,11 @@ def __init__(
123126
self.update_config = update_config
124127
self.deferrable = deferrable
125128
self.job_poll_interval = job_poll_interval
129+
self.stop_job_run_on_kill = stop_job_run_on_kill
130+
self._job_run_id: str | None = None
126131

127-
def execute(self, context: Context):
128-
"""Execute AWS Glue Job from Airflow.
129-
130-
:return: the current Glue job ID.
131-
"""
132+
@cached_property
133+
def glue_job_hook(self) -> GlueJobHook:
132134
if self.script_location is None:
133135
s3_script_location = None
134136
elif not self.script_location.startswith(self.s3_protocol):
@@ -140,7 +142,7 @@ def execute(self, context: Context):
140142
s3_script_location = f"s3://{self.s3_bucket}/{self.s3_artifacts_prefix}{script_name}"
141143
else:
142144
s3_script_location = self.script_location
143-
glue_job = GlueJobHook(
145+
return GlueJobHook(
144146
job_name=self.job_name,
145147
desc=self.job_desc,
146148
concurrent_run_limit=self.concurrent_run_limit,
@@ -155,52 +157,70 @@ def execute(self, context: Context):
155157
update_config=self.update_config,
156158
job_poll_interval=self.job_poll_interval,
157159
)
160+
161+
def execute(self, context: Context):
162+
"""Execute AWS Glue Job from Airflow.
163+
164+
:return: the current Glue job ID.
165+
"""
158166
self.log.info(
159167
"Initializing AWS Glue Job: %s. Wait for completion: %s",
160168
self.job_name,
161169
self.wait_for_completion,
162170
)
163-
glue_job_run = glue_job.initialize_job(self.script_args, self.run_job_kwargs)
171+
glue_job_run = self.glue_job_hook.initialize_job(self.script_args, self.run_job_kwargs)
172+
self._job_run_id = glue_job_run["JobRunId"]
164173
glue_job_run_url = GlueJobRunDetailsLink.format_str.format(
165-
aws_domain=GlueJobRunDetailsLink.get_aws_domain(glue_job.conn_partition),
166-
region_name=glue_job.conn_region_name,
174+
aws_domain=GlueJobRunDetailsLink.get_aws_domain(self.glue_job_hook.conn_partition),
175+
region_name=self.glue_job_hook.conn_region_name,
167176
job_name=urllib.parse.quote(self.job_name, safe=""),
168-
job_run_id=glue_job_run["JobRunId"],
177+
job_run_id=self._job_run_id,
169178
)
170179
GlueJobRunDetailsLink.persist(
171180
context=context,
172181
operator=self,
173-
region_name=glue_job.conn_region_name,
174-
aws_partition=glue_job.conn_partition,
182+
region_name=self.glue_job_hook.conn_region_name,
183+
aws_partition=self.glue_job_hook.conn_partition,
175184
job_name=urllib.parse.quote(self.job_name, safe=""),
176-
job_run_id=glue_job_run["JobRunId"],
185+
job_run_id=self._job_run_id,
177186
)
178187
self.log.info("You can monitor this Glue Job run at: %s", glue_job_run_url)
179188

180189
if self.deferrable:
181190
self.defer(
182191
trigger=GlueJobCompleteTrigger(
183192
job_name=self.job_name,
184-
run_id=glue_job_run["JobRunId"],
193+
run_id=self._job_run_id,
185194
verbose=self.verbose,
186195
aws_conn_id=self.aws_conn_id,
187196
job_poll_interval=self.job_poll_interval,
188197
),
189198
method_name="execute_complete",
190199
)
191200
elif self.wait_for_completion:
192-
glue_job_run = glue_job.job_completion(self.job_name, glue_job_run["JobRunId"], self.verbose)
201+
glue_job_run = self.glue_job_hook.job_completion(self.job_name, self._job_run_id, self.verbose)
193202
self.log.info(
194203
"AWS Glue Job: %s status: %s. Run Id: %s",
195204
self.job_name,
196205
glue_job_run["JobRunState"],
197-
glue_job_run["JobRunId"],
206+
self._job_run_id,
198207
)
199208
else:
200-
self.log.info("AWS Glue Job: %s. Run Id: %s", self.job_name, glue_job_run["JobRunId"])
201-
return glue_job_run["JobRunId"]
209+
self.log.info("AWS Glue Job: %s. Run Id: %s", self.job_name, self._job_run_id)
210+
return self._job_run_id
202211

203212
def execute_complete(self, context, event=None):
    """Finish the task once the deferred trigger has fired.

    :param context: task context; not used by this method.
    :param event: payload delivered by the trigger; its ``status`` and ``value`` keys are read.
    :return: ``event["value"]`` when ``event["status"]`` equals ``"success"``.
    :raises AirflowException: if no event was delivered or its status is not ``"success"``.
    """
    # ``event`` defaults to None; subscripting it would raise an unrelated TypeError
    # instead of the error type used for every other failure of this callback.
    if event is None or event["status"] != "success":
        raise AirflowException(f"Error in glue job: {event}")
    return event["value"]
216+
217+
def on_kill(self):
    """Stop the AWS Glue job run tracked by this operator when the task is killed.

    Nothing happens unless ``stop_job_run_on_kill`` is true. When no run id has
    been recorded (``_job_run_id`` is ``None``) there is nothing to stop, so the
    Glue API is not called. Otherwise a ``batch_stop_job_run`` request is sent for
    that single run, and an error is logged if the response reports no successful
    submissions.
    """
    if not self.stop_job_run_on_kill:
        return
    if self._job_run_id is None:
        # Without this guard the request would carry ``JobRunIds=[None]``.
        self.log.info("No AWS Glue Job run to stop for job: %s", self.job_name)
        return
    self.log.info("Stopping AWS Glue Job: %s. Run Id: %s", self.job_name, self._job_run_id)
    response = self.glue_job_hook.conn.batch_stop_job_run(
        JobName=self.job_name,
        JobRunIds=[self._job_run_id],
    )
    if not response["SuccessfulSubmissions"]:
        self.log.error("Failed to stop AWS Glue Job: %s. Run Id: %s", self.job_name, self._job_run_id)

tests/providers/amazon/aws/operators/test_glue.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,3 +207,46 @@ def test_log_correct_url(
207207
assert job_run_id == JOB_RUN_ID
208208

209209
mock_log_info.assert_any_call("You can monitor this Glue Job run at: %s", glue_job_run_url)
210+
211+
@mock.patch.object(GlueJobHook, "conn")
@mock.patch.object(GlueJobHook, "get_conn")
def test_killed_without_stop_job_run_on_kill(self, _, mock_glue_hook):
    """on_kill must not contact Glue when stop_job_run_on_kill is not passed."""
    operator_kwargs = {
        "task_id": TASK_ID,
        "job_name": JOB_NAME,
        "script_location": "s3://folder/file",
        "aws_conn_id": "aws_default",
        "region_name": "us-west-2",
        "s3_bucket": "some_bucket",
        "iam_role_name": "my_test_role",
    }
    operator = GlueJobOperator(**operator_kwargs)

    operator.on_kill()

    mock_glue_hook.batch_stop_job_run.assert_not_called()
229+
230+
@mock.patch.object(GlueJobHook, "conn")
@mock.patch.object(GlueJobHook, "get_conn")
def test_killed_with_stop_job_run_on_kill(self, _, mock_glue_hook):
    """With stop_job_run_on_kill enabled, on_kill requests a stop for the recorded run."""
    operator_kwargs = {
        "task_id": TASK_ID,
        "job_name": JOB_NAME,
        "script_location": "s3://folder/file",
        "aws_conn_id": "aws_default",
        "region_name": "us-west-2",
        "s3_bucket": "some_bucket",
        "iam_role_name": "my_test_role",
        "stop_job_run_on_kill": True,
    }
    operator = GlueJobOperator(**operator_kwargs)
    # Record a run id directly so the test does not have to go through execute().
    operator._job_run_id = JOB_RUN_ID

    operator.on_kill()

    expected_request = {"JobName": JOB_NAME, "JobRunIds": [JOB_RUN_ID]}
    mock_glue_hook.batch_stop_job_run.assert_called_once_with(**expected_request)

0 commit comments

Comments
 (0)