Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions airflow/providers/amazon/aws/hooks/glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ class GlueJobHook(AwsBaseHook):
- :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""

JOB_POLL_INTERVAL = 6 # polls job status after every JOB_POLL_INTERVAL seconds

class LogContinuationTokens:
"""Used to hold the continuation tokens when reading logs from both streams Glue Jobs write to."""

Expand All @@ -75,6 +73,7 @@ def __init__(
iam_role_name: str | None = None,
create_job_kwargs: dict | None = None,
update_config: bool = False,
job_poll_interval: int | float = 6,
*args,
**kwargs,
):
Expand All @@ -88,6 +87,7 @@ def __init__(
self.s3_glue_logs = "logs/glue-logs/"
self.create_job_kwargs = create_job_kwargs or {}
self.update_config = update_config
self.job_poll_interval = job_poll_interval

worker_type_exists = "WorkerType" in self.create_job_kwargs
num_workers_exists = "NumberOfWorkers" in self.create_job_kwargs
Expand Down Expand Up @@ -278,7 +278,7 @@ def job_completion(self, job_name: str, run_id: str, verbose: bool = False) -> d
if ret:
return ret
else:
time.sleep(self.JOB_POLL_INTERVAL)
time.sleep(self.job_poll_interval)

async def async_job_completion(self, job_name: str, run_id: str, verbose: bool = False) -> dict[str, str]:
"""
Expand All @@ -297,7 +297,7 @@ async def async_job_completion(self, job_name: str, run_id: str, verbose: bool =
if ret:
return ret
else:
await asyncio.sleep(self.JOB_POLL_INTERVAL)
await asyncio.sleep(self.job_poll_interval)

def _handle_state(
self,
Expand Down
4 changes: 4 additions & 0 deletions airflow/providers/amazon/aws/operators/glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def __init__(
deferrable: bool = False,
verbose: bool = False,
update_config: bool = False,
job_poll_interval: int | float = 6,
**kwargs,
):
super().__init__(**kwargs)
Expand All @@ -121,6 +122,7 @@ def __init__(
self.verbose = verbose
self.update_config = update_config
self.deferrable = deferrable
self.job_poll_interval = job_poll_interval

def execute(self, context: Context):
"""Execute AWS Glue Job from Airflow.
Expand Down Expand Up @@ -151,6 +153,7 @@ def execute(self, context: Context):
iam_role_name=self.iam_role_name,
create_job_kwargs=self.create_job_kwargs,
update_config=self.update_config,
job_poll_interval=self.job_poll_interval,
)
self.log.info(
"Initializing AWS Glue Job: %s. Wait for completion: %s",
Expand Down Expand Up @@ -181,6 +184,7 @@ def execute(self, context: Context):
run_id=glue_job_run["JobRunId"],
verbose=self.verbose,
aws_conn_id=self.aws_conn_id,
job_poll_interval=self.job_poll_interval,
),
method_name="execute_complete",
)
Expand Down
6 changes: 5 additions & 1 deletion airflow/providers/amazon/aws/triggers/glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,14 @@ def __init__(
run_id: str,
verbose: bool,
aws_conn_id: str,
job_poll_interval: int | float,
):
super().__init__()
self.job_name = job_name
self.run_id = run_id
self.verbose = verbose
self.aws_conn_id = aws_conn_id
self.job_poll_interval = job_poll_interval

def serialize(self) -> tuple[str, dict[str, Any]]:
return (
Expand All @@ -54,10 +57,11 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"run_id": self.run_id,
"verbose": str(self.verbose),
"aws_conn_id": self.aws_conn_id,
"job_poll_interval": self.job_poll_interval,
},
)

async def run(self) -> AsyncIterator[TriggerEvent]:
hook = GlueJobHook(aws_conn_id=self.aws_conn_id)
hook = GlueJobHook(aws_conn_id=self.aws_conn_id, job_poll_interval=self.job_poll_interval)
await hook.async_job_completion(self.job_name, self.run_id, self.verbose)
yield TriggerEvent({"status": "success", "message": "Job done", "value": self.run_id})
12 changes: 4 additions & 8 deletions tests/providers/amazon/aws/hooks/test_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,7 @@ def test_print_job_logs_no_stream_yet(self, conn_mock: MagicMock, client_mock: M

@mock.patch.object(GlueJobHook, "get_job_state")
def test_job_completion_success(self, get_state_mock: MagicMock):
hook = GlueJobHook()
hook.JOB_POLL_INTERVAL = 0
hook = GlueJobHook(job_poll_interval=0)
get_state_mock.side_effect = [
"RUNNING",
"RUNNING",
Expand All @@ -368,8 +367,7 @@ def test_job_completion_success(self, get_state_mock: MagicMock):

@mock.patch.object(GlueJobHook, "get_job_state")
def test_job_completion_failure(self, get_state_mock: MagicMock):
hook = GlueJobHook()
hook.JOB_POLL_INTERVAL = 0
hook = GlueJobHook(job_poll_interval=0)
get_state_mock.side_effect = [
"RUNNING",
"RUNNING",
Expand All @@ -384,8 +382,7 @@ def test_job_completion_failure(self, get_state_mock: MagicMock):
@pytest.mark.asyncio
@mock.patch.object(GlueJobHook, "async_get_job_state")
async def test_async_job_completion_success(self, get_state_mock: MagicMock):
hook = GlueJobHook()
hook.JOB_POLL_INTERVAL = 0
hook = GlueJobHook(job_poll_interval=0)
get_state_mock.side_effect = [
"RUNNING",
"RUNNING",
Expand All @@ -400,8 +397,7 @@ async def test_async_job_completion_success(self, get_state_mock: MagicMock):
@pytest.mark.asyncio
@mock.patch.object(GlueJobHook, "async_get_job_state")
async def test_async_job_completion_failure(self, get_state_mock: MagicMock):
hook = GlueJobHook()
hook.JOB_POLL_INTERVAL = 0
hook = GlueJobHook(job_poll_interval=0)
get_state_mock.side_effect = [
"RUNNING",
"RUNNING",
Expand Down
4 changes: 2 additions & 2 deletions tests/providers/amazon/aws/triggers/test_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ class TestGlueJobTrigger:
@pytest.mark.asyncio
@mock.patch.object(GlueJobHook, "async_get_job_state")
async def test_wait_job(self, get_state_mock: mock.MagicMock):
GlueJobHook.JOB_POLL_INTERVAL = 0.1
trigger = GlueJobCompleteTrigger(
job_name="job_name",
run_id="JobRunId",
verbose=False,
aws_conn_id="aws_conn_id",
job_poll_interval=0.1,
)
get_state_mock.side_effect = [
"RUNNING",
Expand All @@ -52,12 +52,12 @@ async def test_wait_job(self, get_state_mock: mock.MagicMock):
@pytest.mark.asyncio
@mock.patch.object(GlueJobHook, "async_get_job_state")
async def test_wait_job_failed(self, get_state_mock: mock.MagicMock):
GlueJobHook.JOB_POLL_INTERVAL = 0.1
trigger = GlueJobCompleteTrigger(
job_name="job_name",
run_id="JobRunId",
verbose=False,
aws_conn_id="aws_conn_id",
job_poll_interval=0.1,
)
get_state_mock.side_effect = [
"RUNNING",
Expand Down