Skip to content
98 changes: 80 additions & 18 deletions airflow/providers/amazon/aws/operators/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import datetime
import json
import time
import warnings
Expand Down Expand Up @@ -375,6 +376,7 @@ class SageMakerEndpointOperator(SageMakerBaseOperator):
finish within max_ingestion_time seconds. If you set this parameter to None it never times out.
:param operation: Whether to create an endpoint or update an endpoint. Must be either 'create' or 'update'.
:param aws_conn_id: The AWS connection ID to use.
:param deferrable: If True, the operator waits asynchronously (via a trigger) for the endpoint to be created or updated, freeing up the worker slot while waiting.
:return Dict: Returns The ARN of the endpoint created in Amazon SageMaker.
"""

Expand All @@ -387,15 +389,17 @@ def __init__(
check_interval: int = CHECK_INTERVAL_SECOND,
max_ingestion_time: int | None = None,
operation: str = "create",
deferrable: bool = False,
**kwargs,
):
super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
self.max_ingestion_time = max_ingestion_time
self.max_ingestion_time = max_ingestion_time or 3600 * 10
self.operation = operation.lower()
if self.operation not in ["create", "update"]:
raise ValueError('Invalid value! Argument operation has to be one of "create" and "update"')
self.deferrable = deferrable

def _create_integer_fields(self) -> None:
"""Set fields which should be cast to integers."""
Expand Down Expand Up @@ -436,29 +440,54 @@ def execute(self, context: Context) -> dict:
try:
response = sagemaker_operation(
endpoint_info,
wait_for_completion=self.wait_for_completion,
check_interval=self.check_interval,
max_ingestion_time=self.max_ingestion_time,
wait_for_completion=False,
)
# waiting for completion is handled here in the operator
except ClientError:
self.operation = "update"
sagemaker_operation = self.hook.update_endpoint
log_str = "Updating"
response = sagemaker_operation(
endpoint_info,
wait_for_completion=self.wait_for_completion,
check_interval=self.check_interval,
max_ingestion_time=self.max_ingestion_time,
wait_for_completion=False,
)

if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
raise AirflowException(f"Sagemaker endpoint creation failed: {response}")
else:
return {
"EndpointConfig": serialize(
self.hook.describe_endpoint_config(endpoint_info["EndpointConfigName"])

if self.deferrable:
self.defer(
trigger=SageMakerTrigger(
job_name=endpoint_info["EndpointName"],
job_type="endpoint",
poke_interval=self.check_interval,
aws_conn_id=self.aws_conn_id,
),
"Endpoint": serialize(self.hook.describe_endpoint(endpoint_info["EndpointName"])),
}
method_name="execute_complete",
timeout=datetime.timedelta(seconds=self.max_ingestion_time),
)
elif self.wait_for_completion:
self.hook.get_waiter("endpoint_in_service").wait(
EndpointName=endpoint_info["EndpointName"],
WaiterConfig={"Delay": self.check_interval, "MaxAttempts": self.max_ingestion_time},
)

return {
"EndpointConfig": serialize(
self.hook.describe_endpoint_config(endpoint_info["EndpointConfigName"])
),
"Endpoint": serialize(self.hook.describe_endpoint(endpoint_info["EndpointName"])),
}

def execute_complete(self, context, event=None):
    """Resume execution after the deferred trigger fires.

    :param context: Airflow task context (unused here).
    :param event: Payload emitted by the trigger; expected to contain a
        ``"status"`` key set to ``"success"`` on completion.
    :raises AirflowException: If the trigger reported a failure, or no event
        was delivered at all (e.g. the trigger timed out).
    :return: Serialized descriptions of the endpoint config and the endpoint.
    """
    # A missing event means the trigger finished without reporting success
    # (e.g. a timeout). Treat that like an explicit failure rather than
    # crashing with a TypeError on event["status"].
    if event is None or event.get("status") != "success":
        raise AirflowException(f"Error while running job: {event}")
    # The config may either wrap the endpoint details under an "Endpoint"
    # key or be the endpoint info itself.
    endpoint_info = self.config.get("Endpoint", self.config)
    return {
        "EndpointConfig": serialize(
            self.hook.describe_endpoint_config(endpoint_info["EndpointConfigName"])
        ),
        "Endpoint": serialize(self.hook.describe_endpoint(endpoint_info["EndpointName"])),
    }


class SageMakerTransformOperator(SageMakerBaseOperator):
Expand Down Expand Up @@ -652,6 +681,7 @@ class SageMakerTuningOperator(SageMakerBaseOperator):
:param max_ingestion_time: If wait is set to True, the operation fails
if the tuning job doesn't finish within max_ingestion_time seconds. If you
set this parameter to None, the operation does not timeout.
:param deferrable: If True, the operator waits asynchronously (via a trigger) for the tuning job to complete, freeing up the worker slot while waiting.
:return Dict: Returns The ARN of the tuning job created in Amazon SageMaker.
"""

Expand All @@ -663,12 +693,14 @@ def __init__(
wait_for_completion: bool = True,
check_interval: int = CHECK_INTERVAL_SECOND,
max_ingestion_time: int | None = None,
deferrable: bool = False,
**kwargs,
):
super().__init__(config=config, aws_conn_id=aws_conn_id, **kwargs)
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
self.max_ingestion_time = max_ingestion_time
self.deferrable = deferrable

def expand_role(self) -> None:
"""Expands an IAM role name into an ARN."""
Expand All @@ -695,16 +727,46 @@ def execute(self, context: Context) -> dict:
)
response = self.hook.create_tuning_job(
self.config,
wait_for_completion=self.wait_for_completion,
wait_for_completion=False, # we handle this here
check_interval=self.check_interval,
max_ingestion_time=self.max_ingestion_time,
)
if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
raise AirflowException(f"Sagemaker Tuning Job creation failed: {response}")

if self.deferrable:
self.defer(
trigger=SageMakerTrigger(
job_name=self.config["HyperParameterTuningJobName"],
job_type="tuning",
poke_interval=self.check_interval,
aws_conn_id=self.aws_conn_id,
),
method_name="execute_complete",
timeout=datetime.timedelta(seconds=self.max_ingestion_time)
if self.max_ingestion_time is not None
else None,
)
description = {} # never executed but makes static checkers happy
elif self.wait_for_completion:
description = self.hook.check_status(
self.config["HyperParameterTuningJobName"],
"HyperParameterTuningJobStatus",
self.hook.describe_tuning_job,
self.check_interval,
self.max_ingestion_time,
)
else:
return {
"Tuning": serialize(self.hook.describe_tuning_job(self.config["HyperParameterTuningJobName"]))
}
description = self.hook.describe_tuning_job(self.config["HyperParameterTuningJobName"])

return {"Tuning": serialize(description)}

def execute_complete(self, context, event=None):
    """Resume execution after the deferred trigger fires.

    :param context: Airflow task context (unused here).
    :param event: Payload emitted by the trigger; expected to contain a
        ``"status"`` key set to ``"success"`` on completion.
    :raises AirflowException: If the trigger reported a failure, or no event
        was delivered at all (e.g. the trigger timed out).
    :return: Serialized description of the hyperparameter tuning job.
    """
    # A missing event means the trigger finished without reporting success
    # (e.g. a timeout). Treat that like an explicit failure rather than
    # crashing with a TypeError on event["status"].
    if event is None or event.get("status") != "success":
        raise AirflowException(f"Error while running job: {event}")
    return {
        "Tuning": serialize(self.hook.describe_tuning_job(self.config["HyperParameterTuningJobName"]))
    }


class SageMakerModelOperator(SageMakerBaseOperator):
Expand Down
38 changes: 27 additions & 11 deletions airflow/providers/amazon/aws/triggers/sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from typing import Any

from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
from airflow.triggers.base import BaseTrigger, TriggerEvent


Expand All @@ -41,7 +42,7 @@ def __init__(
job_name: str,
job_type: str,
poke_interval: int = 30,
max_attempts: int | None = None,
max_attempts: int = 480,
Comment on lines -44 to +45
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing None was already broken before this change, so making the parameter non-optional is not a breaking change — the previous behavior never worked.

aws_conn_id: str = "aws_default",
):
super().__init__()
Expand Down Expand Up @@ -74,14 +75,28 @@ def _get_job_type_waiter(job_type: str) -> str:
"training": "TrainingJobComplete",
"transform": "TransformJobComplete",
"processing": "ProcessingJobComplete",
"tuning": "TuningJobComplete",
"endpoint": "endpoint_in_service", # this one is provided by boto
}[job_type.lower()]

@staticmethod
def _get_job_type_waiter_job_name_arg(job_type: str) -> str:
def _get_waiter_arg_name(job_type: str) -> str:
return {
"training": "TrainingJobName",
"transform": "TransformJobName",
"processing": "ProcessingJobName",
"tuning": "HyperParameterTuningJobName",
"endpoint": "EndpointName",
}[job_type.lower()]

@staticmethod
def _get_response_status_key(job_type: str) -> str:
return {
"training": "TrainingJobStatus",
"transform": "TransformJobStatus",
"processing": "ProcessingJobStatus",
"tuning": "HyperParameterTuningJobStatus",
"endpoint": "EndpointStatus",
}[job_type.lower()]

async def run(self):
Expand All @@ -90,12 +105,13 @@ async def run(self):
waiter = self.hook.get_waiter(
self._get_job_type_waiter(self.job_type), deferrable=True, client=client
)
waiter_args = {
self._get_job_type_waiter_job_name_arg(self.job_type): self.job_name,
"WaiterConfig": {
"Delay": self.poke_interval,
"MaxAttempts": self.max_attempts,
},
}
await waiter.wait(**waiter_args)
yield TriggerEvent({"status": "success", "message": "Job completed."})
await async_wait(
waiter=waiter,
waiter_delay=self.poke_interval,
waiter_max_attempts=self.max_attempts,
args={self._get_waiter_arg_name(self.job_type): self.job_name},
failure_message=f"Error while waiting for {self.job_type} job",
status_message=f"{self.job_type} job not done yet",
status_args=[self._get_response_status_key(self.job_type)],
)
yield TriggerEvent({"status": "success", "message": "Job completed."})
26 changes: 26 additions & 0 deletions airflow/providers/amazon/aws/waiters/sagemaker.json
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,32 @@
"state": "failure"
}
]
},
"TuningJobComplete": {
"delay": 30,
"operation": "DescribeHyperParameterTuningJob",
"maxAttempts": 60,
"description": "Wait until job is COMPLETED",
"acceptors": [
{
"matcher": "path",
"argument": "HyperParameterTuningJobStatus",
"expected": "Completed",
"state": "success"
},
{
"matcher": "path",
"argument": "HyperParameterTuningJobStatus",
"expected": "Failed",
"state": "failure"
},
{
"matcher": "path",
"argument": "HyperParameterTuningJobStatus",
"expected": "Stopped",
"state": "failure"
}
]
}
}
}
24 changes: 20 additions & 4 deletions tests/providers/amazon/aws/operators/test_sagemaker_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@
import pytest
from botocore.exceptions import ClientError

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators import sagemaker
from airflow.providers.amazon.aws.operators.sagemaker import SageMakerEndpointOperator
from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger

CREATE_MODEL_PARAMS: dict = {
"ModelName": "model_name",
Expand Down Expand Up @@ -83,12 +84,12 @@ def test_integer_fields(self, serialize, mock_endpoint, mock_endpoint_config, mo
@mock.patch.object(sagemaker, "serialize", return_value="")
def test_execute(self, serialize, mock_endpoint, mock_endpoint_config, mock_model, mock_client):
mock_endpoint.return_value = {"EndpointArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 200}}

self.sagemaker.execute(None)

mock_model.assert_called_once_with(CREATE_MODEL_PARAMS)
mock_endpoint_config.assert_called_once_with(CREATE_ENDPOINT_CONFIG_PARAMS)
mock_endpoint.assert_called_once_with(
CREATE_ENDPOINT_PARAMS, wait_for_completion=False, check_interval=5, max_ingestion_time=None
)
mock_endpoint.assert_called_once_with(CREATE_ENDPOINT_PARAMS, wait_for_completion=False)
assert self.sagemaker.integer_fields == EXPECTED_INTEGER_FIELDS
for variant in self.sagemaker.config["EndpointConfig"]["ProductionVariants"]:
assert variant["InitialInstanceCount"] == int(variant["InitialInstanceCount"])
Expand Down Expand Up @@ -120,3 +121,18 @@ def test_execute_with_duplicate_endpoint_creation(
"ResponseMetadata": {"HTTPStatusCode": 200},
}
self.sagemaker.execute(None)

@mock.patch.object(SageMakerHook, "create_model")
@mock.patch.object(SageMakerHook, "create_endpoint_config")
@mock.patch.object(SageMakerHook, "create_endpoint")
def test_deferred(self, mock_create_endpoint, _, __):
    """In deferrable mode, execute() should defer with a SageMakerTrigger."""
    self.sagemaker.deferrable = True
    mock_create_endpoint.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}}

    with pytest.raises(TaskDeferred) as deferral:
        self.sagemaker.execute(None)

    trigger = deferral.value.trigger
    assert isinstance(trigger, SageMakerTrigger)
    assert trigger.job_name == "endpoint_name"
    assert trigger.job_type == "endpoint"
15 changes: 14 additions & 1 deletion tests/providers/amazon/aws/operators/test_sagemaker_tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@

import pytest

from airflow.exceptions import AirflowException
from airflow.exceptions import AirflowException, TaskDeferred
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
from airflow.providers.amazon.aws.operators import sagemaker
from airflow.providers.amazon.aws.operators.sagemaker import SageMakerTuningOperator
from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerTrigger

EXPECTED_INTEGER_FIELDS: list[list[str]] = [
["HyperParameterTuningJobConfig", "ResourceLimits", "MaxNumberOfTrainingJobs"],
Expand Down Expand Up @@ -107,3 +108,15 @@ def test_execute_with_failure(self, mock_tuning, mock_client):
mock_tuning.return_value = {"TrainingJobArn": "test_arn", "ResponseMetadata": {"HTTPStatusCode": 404}}
with pytest.raises(AirflowException):
self.sagemaker.execute(None)

@mock.patch.object(SageMakerHook, "create_tuning_job")
def test_defers(self, create_mock):
    """In deferrable mode, execute() should defer with a tuning SageMakerTrigger."""
    create_mock.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}}
    self.sagemaker.deferrable = True

    with pytest.raises(TaskDeferred) as deferral:
        self.sagemaker.execute(None)

    trigger = deferral.value.trigger
    assert isinstance(trigger, SageMakerTrigger)
    assert trigger.job_name == "job_name"
    assert trigger.job_type == "tuning"
26 changes: 12 additions & 14 deletions tests/providers/amazon/aws/triggers/test_sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,28 +49,26 @@ def test_sagemaker_trigger_serialize(self):
assert args["aws_conn_id"] == AWS_CONN_ID

@pytest.mark.asyncio
@pytest.mark.parametrize(
"job_type",
[
"training",
"transform",
"processing",
"tuning",
"endpoint",
],
)
@mock.patch("airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.get_waiter")
@mock.patch("airflow.providers.amazon.aws.hooks.sagemaker.SageMakerHook.async_conn")
@mock.patch("airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrigger._get_job_type_waiter")
@mock.patch(
"airflow.providers.amazon.aws.triggers.sagemaker.SageMakerTrigger._get_job_type_waiter_job_name_arg"
)
async def test_sagemaker_trigger_run(
self,
mock_get_job_type_waiter_job_name_arg,
mock_get_job_type_waiter,
mock_async_conn,
mock_get_waiter,
):
mock_get_job_type_waiter_job_name_arg.return_value = "job_name"
mock_get_job_type_waiter.return_value = "waiter"
async def test_sagemaker_trigger_run_all_job_types(self, mock_async_conn, mock_get_waiter, job_type):
mock_async_conn.__aenter__.return_value = mock.MagicMock()

mock_get_waiter().wait = AsyncMock()

sagemaker_trigger = SageMakerTrigger(
job_name=JOB_NAME,
job_type=JOB_TYPE,
job_type=job_type,
poke_interval=POKE_INTERVAL,
max_attempts=MAX_ATTEMPTS,
aws_conn_id=AWS_CONN_ID,
Expand Down
Loading