Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 38 additions & 18 deletions airflow/providers/amazon/aws/operators/glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import os.path
import urllib.parse
from functools import cached_property
from typing import TYPE_CHECKING, Sequence

from airflow import AirflowException
Expand Down Expand Up @@ -60,6 +61,7 @@ class GlueJobOperator(BaseOperator):
(default: False)
:param verbose: If True, Glue Job Run logs show in the Airflow Task Logs. (default: False)
:param update_config: If True, Operator will update job configuration. (default: False)
:param stop_job_run_on_kill: If True, Operator will stop the job run when the task is killed. (default: False)
"""

template_fields: Sequence[str] = (
Expand Down Expand Up @@ -100,6 +102,7 @@ def __init__(
verbose: bool = False,
update_config: bool = False,
job_poll_interval: int | float = 6,
stop_job_run_on_kill: bool = False,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -123,12 +126,11 @@ def __init__(
self.update_config = update_config
self.deferrable = deferrable
self.job_poll_interval = job_poll_interval
self.stop_job_run_on_kill = stop_job_run_on_kill
self._job_run_id: str | None = None

def execute(self, context: Context):
"""Execute AWS Glue Job from Airflow.

:return: the current Glue job ID.
"""
@cached_property
def glue_job_hook(self) -> GlueJobHook:
if self.script_location is None:
s3_script_location = None
elif not self.script_location.startswith(self.s3_protocol):
Expand All @@ -140,7 +142,7 @@ def execute(self, context: Context):
s3_script_location = f"s3://{self.s3_bucket}/{self.s3_artifacts_prefix}{script_name}"
else:
s3_script_location = self.script_location
glue_job = GlueJobHook(
return GlueJobHook(
job_name=self.job_name,
desc=self.job_desc,
concurrent_run_limit=self.concurrent_run_limit,
Expand All @@ -155,52 +157,70 @@ def execute(self, context: Context):
update_config=self.update_config,
job_poll_interval=self.job_poll_interval,
)

def execute(self, context: Context):
    """Execute AWS Glue Job from Airflow.

    Starts a job run through ``self.glue_job_hook``, records the run id on
    the operator (``on_kill`` reads it to stop the run), persists and logs a
    link to the run, and then either defers to a trigger, blocks until the
    run completes, or returns straight away, depending on ``deferrable`` and
    ``wait_for_completion``.

    :param context: Airflow task context, passed on when persisting the run link.
    :return: the current Glue job run ID.
    """
    self.log.info(
        "Initializing AWS Glue Job: %s. Wait for completion: %s",
        self.job_name,
        self.wait_for_completion,
    )
    glue_job_run = self.glue_job_hook.initialize_job(self.script_args, self.run_job_kwargs)
    # Keep the run id on the instance so it is still reachable after execute()
    # has been interrupted.
    self._job_run_id = glue_job_run["JobRunId"]
    glue_job_run_url = GlueJobRunDetailsLink.format_str.format(
        aws_domain=GlueJobRunDetailsLink.get_aws_domain(self.glue_job_hook.conn_partition),
        region_name=self.glue_job_hook.conn_region_name,
        job_name=urllib.parse.quote(self.job_name, safe=""),
        job_run_id=self._job_run_id,
    )
    GlueJobRunDetailsLink.persist(
        context=context,
        operator=self,
        region_name=self.glue_job_hook.conn_region_name,
        aws_partition=self.glue_job_hook.conn_partition,
        job_name=urllib.parse.quote(self.job_name, safe=""),
        job_run_id=self._job_run_id,
    )
    self.log.info("You can monitor this Glue Job run at: %s", glue_job_run_url)

    if self.deferrable:
        # Hand the wait over to the trigger; the task resumes in execute_complete.
        self.defer(
            trigger=GlueJobCompleteTrigger(
                job_name=self.job_name,
                run_id=self._job_run_id,
                verbose=self.verbose,
                aws_conn_id=self.aws_conn_id,
                job_poll_interval=self.job_poll_interval,
            ),
            method_name="execute_complete",
        )
    elif self.wait_for_completion:
        glue_job_run = self.glue_job_hook.job_completion(self.job_name, self._job_run_id, self.verbose)
        self.log.info(
            "AWS Glue Job: %s status: %s. Run Id: %s",
            self.job_name,
            glue_job_run["JobRunState"],
            self._job_run_id,
        )
    else:
        self.log.info("AWS Glue Job: %s. Run Id: %s", self.job_name, self._job_run_id)
    return self._job_run_id

def execute_complete(self, context, event=None):
    """Finish the task once the deferred trigger has reported back.

    :param context: Airflow task context (not used by this method).
    :param event: payload delivered by the trigger; it is read for a
        ``"status"`` key and, on success, a ``"value"`` key.
    :return: ``event["value"]`` when the status is ``"success"``.
    :raises AirflowException: if no event was delivered or its status is
        anything other than ``"success"``.
    """
    # ``event`` defaults to None; subscripting it would raise a bare TypeError,
    # so a missing event is reported as a job failure instead.
    if event is None or event["status"] != "success":
        raise AirflowException(f"Error in glue job: {event}")
    return event["value"]

def on_kill(self):
    """Cancel the running AWS Glue Job.

    Does nothing unless ``stop_job_run_on_kill`` is set. If no run id has
    been recorded yet (``_job_run_id`` is still ``None``), there is no run to
    stop and no stop request is sent. Otherwise a stop request is submitted
    for the recorded run, and an error is logged if the response reports no
    successful submission.
    """
    if not self.stop_job_run_on_kill:
        return
    if self._job_run_id is None:
        # Killed before execute() stored a run id: sending JobRunIds=[None]
        # would only produce a request error.
        self.log.info("No AWS Glue Job run to stop for job: %s", self.job_name)
        return
    self.log.info("Stopping AWS Glue Job: %s. Run Id: %s", self.job_name, self._job_run_id)
    response = self.glue_job_hook.conn.batch_stop_job_run(
        JobName=self.job_name,
        JobRunIds=[self._job_run_id],
    )
    if not response["SuccessfulSubmissions"]:
        self.log.error("Failed to stop AWS Glue Job: %s. Run Id: %s", self.job_name, self._job_run_id)
43 changes: 43 additions & 0 deletions tests/providers/amazon/aws/operators/test_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,46 @@ def test_log_correct_url(
assert job_run_id == JOB_RUN_ID

mock_log_info.assert_any_call("You can monitor this Glue Job run at: %s", glue_job_run_url)

@mock.patch.object(GlueJobHook, "conn")
@mock.patch.object(GlueJobHook, "get_conn")
def test_killed_without_stop_job_run_on_kill(
    self,
    _,
    mock_glue_hook,
):
    """Killing the task must not stop the job run when the option is left at its default."""
    operator_kwargs = dict(
        task_id=TASK_ID,
        job_name=JOB_NAME,
        script_location="s3://folder/file",
        aws_conn_id="aws_default",
        region_name="us-west-2",
        s3_bucket="some_bucket",
        iam_role_name="my_test_role",
    )
    operator = GlueJobOperator(**operator_kwargs)

    operator.on_kill()

    # With stop_job_run_on_kill unset, no stop request may reach the Glue client.
    mock_glue_hook.batch_stop_job_run.assert_not_called()

@mock.patch.object(GlueJobHook, "conn")
@mock.patch.object(GlueJobHook, "get_conn")
def test_killed_with_stop_job_run_on_kill(
    self,
    _,
    mock_glue_hook,
):
    """Killing the task must stop the recorded job run when the option is enabled."""
    operator_kwargs = dict(
        task_id=TASK_ID,
        job_name=JOB_NAME,
        script_location="s3://folder/file",
        aws_conn_id="aws_default",
        region_name="us-west-2",
        s3_bucket="some_bucket",
        iam_role_name="my_test_role",
        stop_job_run_on_kill=True,
    )
    operator = GlueJobOperator(**operator_kwargs)
    # Simulate a run that execute() has already started.
    operator._job_run_id = JOB_RUN_ID

    operator.on_kill()

    expected_request = {"JobName": JOB_NAME, "JobRunIds": [JOB_RUN_ID]}
    mock_glue_hook.batch_stop_job_run.assert_called_once_with(**expected_request)