diff --git a/airflow/providers/amazon/aws/hooks/glue.py b/airflow/providers/amazon/aws/hooks/glue.py index dcd6d7c4661cd..be85f2fec4cf3 100644 --- a/airflow/providers/amazon/aws/hooks/glue.py +++ b/airflow/providers/amazon/aws/hooks/glue.py @@ -181,6 +181,12 @@ def get_or_create_glue_job(self) -> str: s3_log_path = f's3://{self.s3_bucket}/{self.s3_glue_logs}{self.job_name}' execution_role = self.get_iam_execution_role() try: + default_command = { + "Name": "glueetl", + "ScriptLocation": self.script_location, + } + command = self.create_job_kwargs.pop("Command", default_command) + if "WorkerType" in self.create_job_kwargs and "NumberOfWorkers" in self.create_job_kwargs: create_job_response = glue_client.create_job( Name=self.job_name, @@ -188,7 +194,7 @@ def get_or_create_glue_job(self) -> str: LogUri=s3_log_path, Role=execution_role['Role']['Arn'], ExecutionProperty={"MaxConcurrentRuns": self.concurrent_run_limit}, - Command={"Name": "glueetl", "ScriptLocation": self.script_location}, + Command=command, MaxRetries=self.retry_limit, **self.create_job_kwargs, ) @@ -199,7 +205,7 @@ def get_or_create_glue_job(self) -> str: LogUri=s3_log_path, Role=execution_role['Role']['Arn'], ExecutionProperty={"MaxConcurrentRuns": self.concurrent_run_limit}, - Command={"Name": "glueetl", "ScriptLocation": self.script_location}, + Command=command, MaxRetries=self.retry_limit, MaxCapacity=self.num_of_dpus, **self.create_job_kwargs, diff --git a/tests/providers/amazon/aws/hooks/test_glue.py b/tests/providers/amazon/aws/hooks/test_glue.py index 75c2c9a9c5138..4e3afd0453834 100644 --- a/tests/providers/amazon/aws/hooks/test_glue.py +++ b/tests/providers/amazon/aws/hooks/test_glue.py @@ -22,9 +22,9 @@ from airflow.providers.amazon.aws.hooks.glue import GlueJobHook try: - from moto import mock_iam + from moto import mock_glue, mock_iam except ImportError: - mock_iam = None + mock_iam = mock_glue = None class TestGlueJobHook(unittest.TestCase): @@ -57,23 +57,56 @@ def 
test_get_iam_execution_role(self): assert "Arn" in iam_role['Role'] assert iam_role['Role']['Arn'] == "arn:aws:iam::123456789012:role/my_test_role" - @mock.patch.object(GlueJobHook, "get_iam_execution_role") @mock.patch.object(GlueJobHook, "get_conn") - def test_get_or_create_glue_job(self, mock_get_conn, mock_get_iam_execution_role): - mock_get_iam_execution_role.return_value = mock.MagicMock(Role={'RoleName': 'my_test_role'}) + def test_get_or_create_glue_job_get_existing_job(self, mock_get_conn): + """ + Calls 'get_or_create_glue_job' with an existing job. + Should retrieve existing one. + """ + expected_job_name = "simple-job" + mock_get_conn.return_value.get_job.return_value = {"Job": {"Name": expected_job_name}} + some_script = "s3:/glue-examples/glue-scripts/sample_aws_glue_job.py" some_s3_bucket = "my-includes" - mock_glue_job = mock_get_conn.return_value.get_job()['Job']['Name'] - glue_job = GlueJobHook( - job_name='aws_test_glue_job', - desc='This is test case job from Airflow', + hook = GlueJobHook( + job_name="aws_test_glue_job", + desc="This is test case job from Airflow", script_location=some_script, - iam_role_name='my_test_role', + iam_role_name="my_test_role", s3_bucket=some_s3_bucket, region_name=self.some_aws_region, - ).get_or_create_glue_job() - assert glue_job == mock_glue_job + ) + + result = hook.get_or_create_glue_job() + + mock_get_conn.assert_called_once() + mock_get_conn.return_value.get_job.assert_called_once_with(JobName=hook.job_name) + assert result == expected_job_name + + @unittest.skipIf(mock_glue is None, "mock_glue package not present") + @mock_glue + @mock.patch.object(GlueJobHook, "get_iam_execution_role") + def test_get_or_create_glue_job_create_new_job(self, mock_get_iam_execution_role): + """ + Calls 'get_or_create_glue_job' with no existing job. + Should create a new job. 
+ """ + mock_get_iam_execution_role.return_value = {"Role": {"RoleName": "my_test_role", "Arn": "test_role"}} + expected_job_name = "aws_test_glue_job" + + hook = GlueJobHook( + job_name=expected_job_name, + desc="This is test case job from Airflow", + iam_role_name="my_test_role", + script_location="s3://bucket", + s3_bucket="bucket", + region_name=self.some_aws_region, + ) + + result = hook.get_or_create_glue_job() + + assert result == expected_job_name @mock.patch.object(GlueJobHook, "get_iam_execution_role") @mock.patch.object(GlueJobHook, "get_conn")