Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions providers/src/airflow/providers/amazon/aws/operators/ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
:param awslogs_fetch_interval: the interval that the ECS task log fetcher should wait
in between each Cloudwatch logs fetches.
If deferrable is set to True, that parameter is ignored and waiter_delay is used instead.
:param container_name: The name of the container to fetch logs from. If not set, the first container is used.
:param quota_retry: Config if and how to retry the launch of a new ECS task, to handle
transient errors.
:param reattach: If set to True, will check if the task previously launched by the task_instance
Expand Down Expand Up @@ -414,6 +415,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
"awslogs_region",
"awslogs_stream_prefix",
"awslogs_fetch_interval",
"container_name",
"propagate_tags",
"reattach",
"number_logs_exception",
Expand Down Expand Up @@ -445,6 +447,7 @@ def __init__(
awslogs_region: str | None = None,
awslogs_stream_prefix: str | None = None,
awslogs_fetch_interval: timedelta = timedelta(seconds=30),
container_name: str | None = None,
propagate_tags: str | None = None,
quota_retry: dict | None = None,
reattach: bool = False,
Expand Down Expand Up @@ -484,7 +487,7 @@ def __init__(
self.awslogs_region = self.region_name

self.arn: str | None = None
self.container_name: str | None = None
self.container_name: str | None = container_name
self._started_by: str | None = None

self.retry_args = quota_retry
Expand Down Expand Up @@ -628,7 +631,8 @@ def _start_task(self):
self.log.info("ECS Task started: %s", response)

self.arn = response["tasks"][0]["taskArn"]
self.container_name = response["tasks"][0]["containers"][0]["name"]
if not self.container_name:
self.container_name = response["tasks"][0]["containers"][0]["name"]
self.log.info("ECS task ID is: %s", self._get_ecs_task_id(self.arn))

def _try_reattach_task(self, started_by: str):
Expand Down
14 changes: 14 additions & 0 deletions providers/tests/amazon/aws/operators/test_ecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def test_template_fields_overrides(self):
"number_logs_exception",
"wait_for_completion",
"deferrable",
"container_name",
)

@pytest.mark.parametrize(
Expand Down Expand Up @@ -752,6 +753,19 @@ def test_execute_complete(self, client_mock):
# task gets described to assert its success
client_mock().describe_tasks.assert_called_once_with(cluster="test_cluster", tasks=["my_arn"])

@mock.patch.object(EcsBaseOperator, "client")
@mock.patch("airflow.providers.amazon.aws.utils.task_log_fetcher.AwsTaskLogFetcher")
def test_container_name_in_log_stream(self, client_mock, log_fetcher_mock):
container_name = "container-name"
prefix = "prefix"
self.set_up_operator(
awslogs_group="awslogs-group",
awslogs_stream_prefix=prefix,
container_name=container_name
)

assert self.ecs._get_logs_stream_name().startswith(f"{prefix}/{container_name}/")


class TestEcsCreateClusterOperator(EcsBaseTestCase):
@pytest.mark.parametrize("waiter_delay, waiter_max_attempts", WAITERS_TEST_CASES)
Expand Down