Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions providers/src/airflow/providers/amazon/aws/operators/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
:param deferrable: If True, the operator will wait asynchronously for the job to complete.
This implies waiting for completion. This mode requires aiobotocore module to be installed.
(default: False)
:param container_name: The name of the container to fetch logs from. If not set, the first container is used.
Comment thread
eliskovets marked this conversation as resolved.
Outdated
:param do_xcom_push: If True, the operator will push the ECS task ARN to XCom with key 'ecs_task_arn'.
Additionally, if logs are fetched, the last log message will be pushed to XCom with the key 'return_value'. (default: False)
"""
Expand Down Expand Up @@ -419,6 +420,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
"number_logs_exception",
"wait_for_completion",
"deferrable",
"container_name",
)
template_fields_renderers = {
"overrides": "json",
Expand Down Expand Up @@ -455,6 +457,7 @@ def __init__(
# Set the default waiter duration to 70 days (attempts*delay)
# Airflow execution_timeout handles task timeout
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
container_name: str | None = None,
**kwargs,
):
super().__init__(**kwargs)
Expand Down Expand Up @@ -484,7 +487,7 @@ def __init__(
self.awslogs_region = self.region_name

self.arn: str | None = None
self.container_name: str | None = None
self.container_name: str | None = container_name
self._started_by: str | None = None

self.retry_args = quota_retry
Expand Down Expand Up @@ -628,7 +631,8 @@ def _start_task(self):
self.log.info("ECS Task started: %s", response)

self.arn = response["tasks"][0]["taskArn"]
self.container_name = response["tasks"][0]["containers"][0]["name"]
if not self.container_name:
self.container_name = response["tasks"][0]["containers"][0]["name"]
self.log.info("ECS task ID is: %s", self._get_ecs_task_id(self.arn))

def _try_reattach_task(self, started_by: str):
Expand Down
14 changes: 14 additions & 0 deletions providers/tests/amazon/aws/operators/test_ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def test_template_fields_overrides(self):
"number_logs_exception",
"wait_for_completion",
"deferrable",
"container_name",
)

@pytest.mark.parametrize(
Expand Down Expand Up @@ -752,6 +753,19 @@ def test_execute_complete(self, client_mock):
# task gets described to assert its success
client_mock().describe_tasks.assert_called_once_with(cluster="test_cluster", tasks=["my_arn"])

@mock.patch.object(EcsBaseOperator, "client")
@mock.patch("airflow.providers.amazon.aws.utils.task_log_fetcher.AwsTaskLogFetcher")
def test_container_name_in_log_stream(self, client_mock):
container_name = "container-name"
prefix = "prefix"
self.set_up_operator(
awslogs_group="awslogs-group",
awslogs_stream_prefix=prefix,
container_name=container_name
)

assert self.ecs._get_logs_stream_name() == f"{prefix}/{container_name}/{TASK_ID}"


class TestEcsCreateClusterOperator(EcsBaseTestCase):
@pytest.mark.parametrize("waiter_delay, waiter_max_attempts", WAITERS_TEST_CASES)
Expand Down