Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions airflow/providers/amazon/aws/hooks/glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,14 +181,20 @@ def get_or_create_glue_job(self) -> str:
s3_log_path = f's3://{self.s3_bucket}/{self.s3_glue_logs}{self.job_name}'
execution_role = self.get_iam_execution_role()
try:
default_command = {
"Name": "glueetl",
"ScriptLocation": self.script_location,
}
command = self.create_job_kwargs.get("Command", default_command)

if "WorkerType" in self.create_job_kwargs and "NumberOfWorkers" in self.create_job_kwargs:
create_job_response = glue_client.create_job(
Name=self.job_name,
Description=self.desc,
LogUri=s3_log_path,
Role=execution_role['Role']['Arn'],
ExecutionProperty={"MaxConcurrentRuns": self.concurrent_run_limit},
Command={"Name": "glueetl", "ScriptLocation": self.script_location},
Command=command,
MaxRetries=self.retry_limit,
**self.create_job_kwargs,
)
Expand All @@ -199,7 +205,7 @@ def get_or_create_glue_job(self) -> str:
LogUri=s3_log_path,
Role=execution_role['Role']['Arn'],
ExecutionProperty={"MaxConcurrentRuns": self.concurrent_run_limit},
Command={"Name": "glueetl", "ScriptLocation": self.script_location},
Command=command,
MaxRetries=self.retry_limit,
MaxCapacity=self.num_of_dpus,
**self.create_job_kwargs,
Expand Down
57 changes: 45 additions & 12 deletions tests/providers/amazon/aws/hooks/test_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
from airflow.providers.amazon.aws.hooks.glue import GlueJobHook

try:
from moto import mock_iam
from moto import mock_glue, mock_iam
except ImportError:
mock_iam = None
mock_iam = mock_glue = None


class TestGlueJobHook(unittest.TestCase):
Expand Down Expand Up @@ -57,23 +57,56 @@ def test_get_iam_execution_role(self):
assert "Arn" in iam_role['Role']
assert iam_role['Role']['Arn'] == "arn:aws:iam::123456789012:role/my_test_role"

@mock.patch.object(GlueJobHook, "get_iam_execution_role")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since there's an existing job, this method doesn't get called at all.

@mock.patch.object(GlueJobHook, "get_conn")
def test_get_or_create_glue_job(self, mock_get_conn, mock_get_iam_execution_role):
mock_get_iam_execution_role.return_value = mock.MagicMock(Role={'RoleName': 'my_test_role'})
def test_get_or_create_glue_job_get_existing_job(self, mock_get_conn):
"""
Calls 'get_or_create_glue_job' with a existing job.
Should retrieve existing one.
"""
expected_job_name = "simple-job"
mock_get_conn.return_value.get_job.return_value = {"Job": {"Name": expected_job_name}}

some_script = "s3:/glue-examples/glue-scripts/sample_aws_glue_job.py"
some_s3_bucket = "my-includes"

mock_glue_job = mock_get_conn.return_value.get_job()['Job']['Name']
glue_job = GlueJobHook(
job_name='aws_test_glue_job',
desc='This is test case job from Airflow',
hook = GlueJobHook(
job_name="aws_test_glue_job",
desc="This is test case job from Airflow",
script_location=some_script,
iam_role_name='my_test_role',
iam_role_name="my_test_role",
s3_bucket=some_s3_bucket,
region_name=self.some_aws_region,
).get_or_create_glue_job()
assert glue_job == mock_glue_job
)

result = hook.get_or_create_glue_job()

mock_get_conn.assert_called_once()
mock_get_conn.return_value.get_job.assert_called_once_with(JobName=hook.job_name)
assert result == expected_job_name

@unittest.skipIf(mock_glue is None, "mock_glue package not present")
@mock_glue
@mock.patch.object(GlueJobHook, "get_iam_execution_role")
def test_get_or_create_glue_job_create_new_job(self, mock_get_iam_execution_role):
    """
    Exercise 'get_or_create_glue_job' when no job exists yet.

    The IAM role lookup is patched to return a canned role, and the hook
    is expected to create the job and return the configured job name.
    """
    job_name = "aws_test_glue_job"
    # Canned role payload; the hook reads the ARN from role['Role']['Arn'].
    canned_role = {"RoleName": "my_test_role", "Arn": "test_role"}
    mock_get_iam_execution_role.return_value = {"Role": canned_role}

    glue_hook = GlueJobHook(
        job_name=job_name,
        desc="This is test case job from Airflow",
        iam_role_name="my_test_role",
        script_location="s3://bucket",
        s3_bucket="bucket",
        region_name=self.some_aws_region,
    )

    created_job_name = glue_hook.get_or_create_glue_job()

    assert created_job_name == job_name

@mock.patch.object(GlueJobHook, "get_iam_execution_role")
@mock.patch.object(GlueJobHook, "get_conn")
Expand Down