diff --git a/airflow/providers/cncf/kubernetes/callbacks.py b/airflow/providers/cncf/kubernetes/callbacks.py new file mode 100644 index 0000000000000..4baef440deec6 --- /dev/null +++ b/airflow/providers/cncf/kubernetes/callbacks.py @@ -0,0 +1,111 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from enum import Enum +from typing import Union + +import kubernetes.client as k8s +import kubernetes_asyncio.client as async_k8s + +client_type = Union[k8s.CoreV1Api, async_k8s.CoreV1Api] + + +class ExecutionMode(str, Enum): + """Enum class for execution mode.""" + + SYNC = "sync" + ASYNC = "async" + + +class KubernetesPodOperatorCallback: + """`KubernetesPodOperator` callbacks methods. + + Currently, the callbacks methods are not called in the async mode, this support will be added + in the future. + """ + + @staticmethod + def on_sync_client_creation(*, client: k8s.CoreV1Api, **kwargs) -> None: + """Callback method called after creating the sync client. + + :param client: the created `kubernetes.client.CoreV1Api` client. + """ + pass + + @staticmethod + def on_pod_creation(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None: + """Callback method called after creating the pod. + + :param pod: the created pod. + :param client: the Kubernetes client that can be used in the callback. + :param mode: the current execution mode, it's one of (`sync`, `async`). + """ + pass + + @staticmethod + def on_pod_starting(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None: + """Callback method called when the pod starts. + + :param pod: the started pod. + :param client: the Kubernetes client that can be used in the callback. + :param mode: the current execution mode, it's one of (`sync`, `async`). + """ + pass + + @staticmethod + def on_pod_completion(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None: + """Callback method called when the pod completes. + + :param pod: the completed pod. + :param client: the Kubernetes client that can be used in the callback. + :param mode: the current execution mode, it's one of (`sync`, `async`). + """ + pass + + @staticmethod + def on_pod_cleanup(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs): + """Callback method called after cleaning/deleting the pod. + + :param pod: the completed pod. + :param client: the Kubernetes client that can be used in the callback. + :param mode: the current execution mode, it's one of (`sync`, `async`). + """ + pass + + @staticmethod + def on_operator_resuming( + *, pod: k8s.V1Pod, event: dict, client: client_type, mode: str, **kwargs + ) -> None: + """Callback method called when resuming the `KubernetesPodOperator` from deferred state. + + :param pod: the current state of the pod. + :param event: the returned event from the Trigger. + :param client: the Kubernetes client that can be used in the callback. + :param mode: the current execution mode, it's one of (`sync`, `async`). + """ + pass + + @staticmethod + def progress_callback(*, line: str, client: client_type, mode: str, **kwargs) -> None: + """Callback method to process pod container logs. + + :param line: the read line of log. + :param client: the Kubernetes client that can be used in the callback. + :param mode: the current execution mode, it's one of (`sync`, `async`). + """ + pass diff --git a/airflow/providers/cncf/kubernetes/operators/pod.py b/airflow/providers/cncf/kubernetes/operators/pod.py index 70f8bc2252bfb..0a06ea6ec6149 100644 --- a/airflow/providers/cncf/kubernetes/operators/pod.py +++ b/airflow/providers/cncf/kubernetes/operators/pod.py @@ -49,6 +49,7 @@ convert_volume, convert_volume_mount, ) +from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode, KubernetesPodOperatorCallback from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import ( POD_NAME_MAX_LENGTH, @@ -198,7 +199,10 @@ class KubernetesPodOperator(BaseOperator): Default value is "File" :param active_deadline_seconds: The active_deadline_seconds which translates to active_deadline_seconds in V1PodSpec. + :param callbacks: KubernetesPodOperatorCallback instance contains the callbacks methods on different step + of KubernetesPodOperator. :param progress_callback: Callback function for receiving k8s container logs. + `progress_callback` is deprecated, please use :param `callbacks` instead. """ # !!! Changes in KubernetesPodOperator's arguments should be also reflected in !!! @@ -290,6 +294,7 @@ def __init__( is_delete_operator_pod: None | bool = None, termination_message_policy: str = "File", active_deadline_seconds: int | None = None, + callbacks: type[KubernetesPodOperatorCallback] | None = None, progress_callback: Callable[[str], None] | None = None, **kwargs, ) -> None: @@ -381,6 +386,7 @@ def __init__( self._config_dict: dict | None = None # TODO: remove it when removing convert_config_file_to_dict self._progress_callback = progress_callback + self.callbacks = callbacks self._killed: bool = False @cached_property @@ -459,7 +465,9 @@ def _get_ti_pod_labels(context: Context | None = None, include_try_number: bool @cached_property def pod_manager(self) -> PodManager: - return PodManager(kube_client=self.client, progress_callback=self._progress_callback) + return PodManager( + kube_client=self.client, callbacks=self.callbacks, progress_callback=self._progress_callback + ) @cached_property def hook(self) -> PodOperatorHookProtocol: @@ -473,7 +481,10 @@ def hook(self) -> PodOperatorHookProtocol: @cached_property def client(self) -> CoreV1Api: - return self.hook.core_v1_client + client = self.hook.core_v1_client + if self.callbacks: + self.callbacks.on_sync_client_creation(client=client) + return client def find_pod(self, namespace: str, context: Context, *, exclude_checked: bool = True) -> k8s.V1Pod | None: """Return an already-running pod for this task instance if one exists.""" @@ -552,7 +563,17 @@ def execute_sync(self, context: Context): # get remote pod for use in cleanup methods self.remote_pod = self.find_pod(self.pod.metadata.namespace, context=context) + if self.callbacks: + self.callbacks.on_pod_creation( + pod=self.remote_pod, client=self.client, mode=ExecutionMode.SYNC + ) self.await_pod_start(pod=self.pod) + if self.callbacks: + self.callbacks.on_pod_starting( + pod=self.find_pod(self.pod.metadata.namespace, context=context), + client=self.client, + mode=ExecutionMode.SYNC, + ) if self.get_logs: self.pod_manager.fetch_requested_container_logs( @@ -566,6 +587,12 @@ def execute_sync(self, context: Context): self.pod_manager.await_container_completion( pod=self.pod, container_name=self.base_container_name ) + if self.callbacks: + self.callbacks.on_pod_completion( + pod=self.find_pod(self.pod.metadata.namespace, context=context), + client=self.client, + mode=ExecutionMode.SYNC, + ) if self.do_xcom_push: self.pod_manager.await_xcom_sidecar_container_start(pod=self.pod) @@ -575,10 +602,13 @@ def execute_sync(self, context: Context): self.pod, istio_enabled, self.base_container_name ) finally: + pod_to_clean = self.pod or self.pod_request_obj self.cleanup( - pod=self.pod or self.pod_request_obj, + pod=pod_to_clean, remote_pod=self.remote_pod, ) + if self.callbacks: + self.callbacks.on_pod_cleanup(pod=pod_to_clean, client=self.client, mode=ExecutionMode.SYNC) if self.do_xcom_push: return result @@ -589,6 +619,12 @@ def execute_async(self, context: Context): pod_request_obj=self.pod_request_obj, context=context, ) + if self.callbacks: + self.callbacks.on_pod_creation( + pod=self.find_pod(self.pod.metadata.namespace, context=context), + client=self.client, + mode=ExecutionMode.SYNC, + ) ti = context["ti"] ti.xcom_push(key="pod_name", value=self.pod.metadata.name) ti.xcom_push(key="pod_namespace", value=self.pod.metadata.namespace) @@ -625,6 +661,10 @@ def execute_complete(self, context: Context, event: dict, **kwargs): event["name"], event["namespace"], ) + if self.callbacks: + self.callbacks.on_operator_resuming( + pod=pod, event=event, client=self.client, mode=ExecutionMode.SYNC + ) if event["status"] in ("error", "failed", "timeout"): # fetch some logs when pod is failed if self.get_logs: @@ -677,6 +717,8 @@ def post_complete_action(self, *, pod, remote_pod, **kwargs): pod=pod, remote_pod=remote_pod, ) + if self.callbacks: + self.callbacks.on_pod_cleanup(pod=pod, client=self.client, mode=ExecutionMode.SYNC) def cleanup(self, pod: k8s.V1Pod, remote_pod: k8s.V1Pod): # If a task got marked as failed, "on_kill" method would be called and the pod will be cleaned up diff --git a/airflow/providers/cncf/kubernetes/utils/pod_manager.py b/airflow/providers/cncf/kubernetes/utils/pod_manager.py index e2d0efac830b9..0e736daa6aa76 100644 --- a/airflow/providers/cncf/kubernetes/utils/pod_manager.py +++ b/airflow/providers/cncf/kubernetes/utils/pod_manager.py @@ -40,6 +40,7 @@ from urllib3.exceptions import HTTPError, TimeoutError from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning +from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode, KubernetesPodOperatorCallback from airflow.providers.cncf.kubernetes.pod_generator import PodDefaults from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.timezone import utcnow @@ -50,6 +51,7 @@ from kubernetes.client.models.v1_pod import V1Pod from urllib3.response import HTTPResponse + EMPTY_XCOM_RESULT = "__airflow_xcom_result_empty__" """ Sentinel for no xcom result. @@ -287,18 +289,22 @@ class PodManager(LoggingMixin): def __init__( self, kube_client: client.CoreV1Api, + callbacks: type[KubernetesPodOperatorCallback] | None = None, progress_callback: Callable[[str], None] | None = None, ): """ Create the launcher. :param kube_client: kubernetes client + :param callbacks: :param progress_callback: Callback function invoked when fetching container log. + This parameter is deprecated, please use ```` """ super().__init__() self._client = kube_client self._progress_callback = progress_callback self._watch = watch.Watch() + self._callbacks = callbacks def run_pod_async(self, pod: V1Pod, **kwargs) -> V1Pod: """Run POD asynchronously.""" @@ -441,9 +447,13 @@ def consume_logs(*, since_time: DateTime | None = None) -> tuple[DateTime | None message_timestamp = line_timestamp progress_callback_lines.append(line) else: # previous log line is complete - if self._progress_callback: - for line in progress_callback_lines: + for line in progress_callback_lines: + if self._progress_callback: self._progress_callback(line) + if self._callbacks: + self._callbacks.progress_callback( + line=line, client=self._client, mode=ExecutionMode.SYNC + ) self.log.info("[%s] %s", container_name, message_to_log) last_captured_timestamp = message_timestamp message_to_log = message @@ -454,9 +464,13 @@ def consume_logs(*, since_time: DateTime | None = None) -> tuple[DateTime | None progress_callback_lines.append(line) finally: # log the last line and update the last_captured_timestamp - if self._progress_callback: - for line in progress_callback_lines: + for line in progress_callback_lines: + if self._progress_callback: self._progress_callback(line) + if self._callbacks: + self._callbacks.progress_callback( + line=line, client=self._client, mode=ExecutionMode.SYNC + ) self.log.info("[%s] %s", container_name, message_to_log) last_captured_timestamp = message_timestamp except TimeoutError as e: diff --git a/docs/apache-airflow-providers-cncf-kubernetes/operators.rst b/docs/apache-airflow-providers-cncf-kubernetes/operators.rst index 690a857ea02ea..f24fd61602efa 100644 --- a/docs/apache-airflow-providers-cncf-kubernetes/operators.rst +++ b/docs/apache-airflow-providers-cncf-kubernetes/operators.rst @@ -195,6 +195,81 @@ included in the exception message if the task fails. Read more on termination-log `here `__. +KubernetesPodOperator callbacks +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The :class:`~airflow.providers.cncf.kubernetes.operators.kubernetes_pod.KubernetesPodOperator` supports different +callbacks that can be used to trigger actions during the lifecycle of the pod. In order to use them, you need to +create a subclass of :class:`~airflow.providers.cncf.kubernetes.callbacks.KubernetesPodOperatorCallback` and override +the callbacks methods you want to use. Then you can pass your callback class to the operator using the ``callbacks`` +parameter. + +The following callbacks are supported: + +* on_sync_client_creation: called after creating the sync client +* on_pod_creation: called after creating the pod +* on_pod_starting: called after the pod starts +* on_pod_completion: called when the pod completes +* on_pod_cleanup: called after cleaning/deleting the pod +* on_operator_resuming: when resuming the task from deferred state +* progress_callback: called on each line of containers logs + +Currently, the callbacks methods are not called in the async mode, this support will be added in the future. + +Example: +~~~~~~~~ +.. code-block:: python + + import kubernetes.client as k8s + import kubernetes_asyncio.client as async_k8s + + from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator + from airflow.providers.cncf.kubernetes.callbacks import KubernetesPodOperatorCallback + + + class MyCallback(KubernetesPodOperatorCallback): + @staticmethod + def on_pod_creation(*, pod: k8s.V1Pod, client: k8s.CoreV1Api, mode: str, **kwargs) -> None: + client.create_namespaced_service( + namespace=pod.metadata.namespace, + body=k8s.V1Service( + metadata=k8s.V1ObjectMeta( + name=pod.metadata.name, + labels=pod.metadata.labels, + owner_references=[ + k8s.V1OwnerReference( + api_version=pod.api_version, + kind=pod.kind, + name=pod.metadata.name, + uid=pod.metadata.uid, + controller=True, + block_owner_deletion=True, + ) + ], + ), + spec=k8s.V1ServiceSpec( + selector=pod.metadata.labels, + ports=[ + k8s.V1ServicePort( + name="http", + port=80, + target_port=80, + ) + ], + ), + ), + ) + + + k = KubernetesPodOperator( + task_id="test_callback", + image="alpine", + cmds=["/bin/sh"], + arguments=["-c", "echo hello world; echo Custom error > /dev/termination-log; exit 1;"], + name="test-callback", + callbacks=MyCallback, + ) + Reference ^^^^^^^^^ For further information, look at: diff --git a/tests/providers/cncf/kubernetes/operators/test_pod.py b/tests/providers/cncf/kubernetes/operators/test_pod.py index 690dece9534c9..a737e06cbefa4 100644 --- a/tests/providers/cncf/kubernetes/operators/test_pod.py +++ b/tests/providers/cncf/kubernetes/operators/test_pod.py @@ -1446,6 +1446,119 @@ def test_get_logs_but_not_for_base_container( # check that we wait for the xcom sidecar to start before extracting XCom mock_await_xcom_sidecar.assert_called_once_with(pod=pod) + @patch(HOOK_CLASS, new=MagicMock) + @patch(KUB_OP_PATH.format("find_pod")) + def test_execute_sync_callbacks(self, find_pod_mock): + from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode + + from ..test_callbacks import MockKubernetesPodOperatorCallback, MockWrapper + + MockWrapper.reset() + mock_callbacks = MockWrapper.mock_callbacks + found_pods = [MagicMock(), MagicMock(), MagicMock()] + find_pod_mock.side_effect = [None] + found_pods + + remote_pod_mock = MagicMock() + remote_pod_mock.status.phase = "Succeeded" + self.await_pod_mock.return_value = remote_pod_mock + k = KubernetesPodOperator( + namespace="default", + image="ubuntu:16.04", + cmds=["bash", "-cx"], + arguments=["echo 10"], + labels={"foo": "bar"}, + name="test", + task_id="task", + do_xcom_push=False, + callbacks=MockKubernetesPodOperatorCallback, + ) + self.run_pod(k) + + # check on_sync_client_creation callback + mock_callbacks.on_sync_client_creation.assert_called_once() + assert mock_callbacks.on_sync_client_creation.call_args.kwargs == {"client": k.client} + + # check on_pod_creation callback + mock_callbacks.on_pod_creation.assert_called_once() + assert mock_callbacks.on_pod_creation.call_args.kwargs == { + "client": k.client, + "mode": ExecutionMode.SYNC, + "pod": found_pods[0], + } + + # check on_pod_starting callback + mock_callbacks.on_pod_starting.assert_called_once() + assert mock_callbacks.on_pod_starting.call_args.kwargs == { + "client": k.client, + "mode": ExecutionMode.SYNC, + "pod": found_pods[1], + } + + # check on_pod_completion callback + mock_callbacks.on_pod_completion.assert_called_once() + assert mock_callbacks.on_pod_completion.call_args.kwargs == { + "client": k.client, + "mode": ExecutionMode.SYNC, + "pod": found_pods[2], + } + + # check on_pod_cleanup callback + mock_callbacks.on_pod_cleanup.assert_called_once() + assert mock_callbacks.on_pod_cleanup.call_args.kwargs == { + "client": k.client, + "mode": ExecutionMode.SYNC, + "pod": k.pod, + } + + @patch(HOOK_CLASS, new=MagicMock) + def test_execute_async_callbacks(self): + from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode + + from ..test_callbacks import MockKubernetesPodOperatorCallback, MockWrapper + + MockWrapper.reset() + mock_callbacks = MockWrapper.mock_callbacks + remote_pod_mock = MagicMock() + remote_pod_mock.status.phase = "Succeeded" + self.await_pod_mock.return_value = remote_pod_mock + + k = KubernetesPodOperator( + namespace="default", + image="ubuntu:16.04", + cmds=["bash", "-cx"], + arguments=["echo 10"], + labels={"foo": "bar"}, + name="test", + task_id="task", + do_xcom_push=False, + callbacks=MockKubernetesPodOperatorCallback, + ) + k.execute_complete( + context=create_context(k), + event={ + "status": "success", + "message": TEST_SUCCESS_MESSAGE, + "name": TEST_NAME, + "namespace": TEST_NAMESPACE, + }, + ) + + # check on_operator_resuming callback + mock_callbacks.on_pod_cleanup.assert_called_once() + assert mock_callbacks.on_pod_cleanup.call_args.kwargs == { + "client": k.client, + "mode": ExecutionMode.SYNC, + "pod": remote_pod_mock, + } + + # check on_pod_cleanup callback + mock_callbacks.on_pod_cleanup.assert_called_once() + assert mock_callbacks.on_pod_cleanup.call_args.kwargs == { + "client": k.client, + "mode": ExecutionMode.SYNC, + "pod": remote_pod_mock, + } + class TestSuppress: def test__suppress(self, caplog): @@ -1554,9 +1667,13 @@ def run_pod_async(self, operator: KubernetesPodOperator, map_index: int = -1): return remote_pod_mock @pytest.mark.parametrize("do_xcom_push", [True, False]) + @patch(KUB_OP_PATH.format("client")) + @patch(KUB_OP_PATH.format("find_pod")) @patch(KUB_OP_PATH.format("build_pod_request_obj")) @patch(KUB_OP_PATH.format("get_or_create_pod")) - def test_async_create_pod_should_execute_successfully(self, mocked_pod, mocked_pod_obj, do_xcom_push): + def test_async_create_pod_should_execute_successfully( + self, mocked_pod, mocked_pod_obj, mocked_found_pod, mocked_client, do_xcom_push + ): """ Asserts that a task is deferred and the KubernetesCreatePodTrigger will be fired when the KubernetesPodOperator is executed in deferrable mode when deferrable=True. @@ -1584,7 +1701,7 @@ def test_async_create_pod_should_execute_successfully(self, mocked_pod, mocked_p mocked_pod.return_value.metadata.namespace = TEST_NAMESPACE context = create_context(k) - ti_mock = MagicMock() + ti_mock = MagicMock(**{"map_index": -1}) context["ti"] = ti_mock with pytest.raises(TaskDeferred) as exc: diff --git a/tests/providers/cncf/kubernetes/test_callbacks.py b/tests/providers/cncf/kubernetes/test_callbacks.py new file mode 100644 index 0000000000000..2757b8296a7d0 --- /dev/null +++ b/tests/providers/cncf/kubernetes/test_callbacks.py @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest.mock import MagicMock + +from airflow.providers.cncf.kubernetes.callbacks import KubernetesPodOperatorCallback + + +class MockWrapper: + mock_callbacks = MagicMock() + + @classmethod + def reset(cls): + cls.mock_callbacks.reset_mock() + + +class MockKubernetesPodOperatorCallback(KubernetesPodOperatorCallback): + """`KubernetesPodOperator` callbacks methods.""" + + @staticmethod + def on_sync_client_creation(*args, **kwargs) -> None: + MockWrapper.mock_callbacks.on_sync_client_creation(*args, **kwargs) + + @staticmethod + def on_async_client_creation(*args, **kwargs) -> None: + MockWrapper.mock_callbacks.on_async_client_creation(*args, **kwargs) + + @staticmethod + def on_pod_creation(*args, **kwargs) -> None: + MockWrapper.mock_callbacks.on_pod_creation(*args, **kwargs) + + @staticmethod + def on_pod_starting(*args, **kwargs) -> None: + MockWrapper.mock_callbacks.on_pod_starting(*args, **kwargs) + + @staticmethod + def on_pod_completion(*args, **kwargs) -> None: + MockWrapper.mock_callbacks.on_pod_completion(*args, **kwargs) + + @staticmethod + def on_pod_cleanup(*args, **kwargs) -> None: + MockWrapper.mock_callbacks.on_pod_cleanup(*args, **kwargs) + + @staticmethod + def on_operator_resuming(*args, **kwargs) -> None: + MockWrapper.mock_callbacks.on_operator_resuming(*args, **kwargs) + + @staticmethod + def progress_callback(*args, **kwargs) -> None: + MockWrapper.mock_callbacks.progress_callback(*args, **kwargs) diff --git a/tests/providers/cncf/kubernetes/utils/test_pod_manager.py b/tests/providers/cncf/kubernetes/utils/test_pod_manager.py index 13f812355093a..fc09d6bb02e78 100644 --- a/tests/providers/cncf/kubernetes/utils/test_pod_manager.py +++ b/tests/providers/cncf/kubernetes/utils/test_pod_manager.py @@ -41,6 +41,8 @@ ) from airflow.utils.timezone import utc +from ..test_callbacks import MockKubernetesPodOperatorCallback, MockWrapper + if TYPE_CHECKING: from pendulum import DateTime @@ -50,7 +52,9 @@ def setup_method(self): self.mock_progress_callback = mock.Mock() self.mock_kube_client = mock.Mock() self.pod_manager = PodManager( - kube_client=self.mock_kube_client, progress_callback=self.mock_progress_callback + kube_client=self.mock_kube_client, + callbacks=MockKubernetesPodOperatorCallback, + progress_callback=self.mock_progress_callback, ) def test_read_pod_logs_successfully_returns_logs(self): @@ -274,7 +278,7 @@ def test_fetch_container_logs_returning_last_timestamp( @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.container_is_running") @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.read_pod_logs") - def test_fetch_container_logs_invoke_progress_callback( + def test_fetch_container_logs_invoke_deprecated_progress_callback( self, mock_read_pod_logs, mock_container_is_running ): message = "2020-10-08T14:16:17.793417674Z message" @@ -285,8 +289,30 @@ def test_fetch_container_logs_invoke_progress_callback( self.pod_manager.fetch_container_logs(mock.MagicMock(), mock.MagicMock(), follow=True) self.mock_progress_callback.assert_has_calls([mock.call(message), mock.call(no_ts_message)]) + @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.container_is_running") + @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.read_pod_logs") + def test_fetch_container_logs_invoke_progress_callback( + self, mock_read_pod_logs, mock_container_is_running + ): + MockWrapper.reset() + mock_callbacks = MockWrapper.mock_callbacks + message = "2020-10-08T14:16:17.793417674Z message" + no_ts_message = "notimestamp" + mock_read_pod_logs.return_value = [bytes(message, "utf-8"), bytes(no_ts_message, "utf-8")] + mock_container_is_running.return_value = False + + self.pod_manager.fetch_container_logs(mock.MagicMock(), mock.MagicMock(), follow=True) + mock_callbacks.progress_callback.assert_has_calls( + [ + mock.call(line=message, client=self.pod_manager._client, mode="sync"), + mock.call(line=no_ts_message, client=self.pod_manager._client, mode="sync"), + ] + ) + @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.container_is_running") def test_fetch_container_logs_failures(self, mock_container_is_running): + MockWrapper.reset() + mock_callbacks = MockWrapper.mock_callbacks last_timestamp_string = "2020-10-08T14:18:17.793417674Z" messages = [ bytes("2020-10-08T14:16:17.793417674Z message", "utf-8"), @@ -309,6 +335,7 @@ def consumer_iter(): status = self.pod_manager.fetch_container_logs(mock.MagicMock(), mock.MagicMock(), follow=True) assert status.last_log_time == cast("DateTime", pendulum.parse(last_timestamp_string)) assert self.mock_progress_callback.call_count == expected_call_count + assert mock_callbacks.progress_callback.call_count == expected_call_count @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.container_is_running") @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.read_pod_logs")