From b841b49a7b88cb42cec9c761ab97bf20adc08513 Mon Sep 17 00:00:00 2001 From: hussein-awala Date: Fri, 17 Nov 2023 23:07:19 +0200 Subject: [PATCH 01/15] Create a generic callbacks class for KubernetesPodOperator --- .../providers/cncf/kubernetes/callbacks.py | 113 ++++++++++++++++++ .../cncf/kubernetes/operators/pod.py | 37 +++++- .../providers/cncf/kubernetes/triggers/pod.py | 28 ++++- .../cncf/kubernetes/utils/pod_manager.py | 19 ++- .../cloud/triggers/kubernetes_engine.py | 8 +- .../cncf/kubernetes/operators/test_pod.py | 8 +- .../cncf/kubernetes/triggers/test_pod.py | 10 +- .../cloud/operators/test_kubernetes_engine.py | 10 +- .../cloud/triggers/test_kubernetes_engine.py | 6 +- 9 files changed, 225 insertions(+), 14 deletions(-) create mode 100644 airflow/providers/cncf/kubernetes/callbacks.py diff --git a/airflow/providers/cncf/kubernetes/callbacks.py b/airflow/providers/cncf/kubernetes/callbacks.py new file mode 100644 index 0000000000000..881e585d617f1 --- /dev/null +++ b/airflow/providers/cncf/kubernetes/callbacks.py @@ -0,0 +1,113 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from enum import Enum + +import kubernetes.client as k8s +import kubernetes_asyncio.client as async_k8s + +client_type = k8s.CoreV1Api | async_k8s.CoreV1Api + + +class ExecutionMode(str, Enum): + """Enum class for execution mode.""" + + SYNC = "sync" + ASYNC = "async" + + +class KubernetesPodOperatorCallback: + """`KubernetesPodOperator` callbacks methods.""" + + @staticmethod + def on_sync_client_creation(*, client: k8s.CoreV1Api, **kwargs) -> None: + """Callback method called after creating the sync client. + + :param client: the created `kubernetes.client.CoreV1Api` client. + """ + pass + + @staticmethod + def on_async_client_creation(*, client: async_k8s.CoreV1Api, **kwargs) -> None: + """Callback method called after creating the async client. + + :param client: the created `kubernetes_asyncio.client.CoreV1Api` client. + """ + pass + + @staticmethod + def on_pod_creation(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None: + """Callback method called after creating the pod. + + :param pod: the created pod. + :param client: the Kubernetes client that can be used in the callback. + :param mode: the current execution mode, it's one of (`sync`, `async`). + """ + pass + + @staticmethod + def on_pod_starting(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None: + """Callback method called when the pod starts. + + :param pod: the started pod. + :param client: the Kubernetes client that can be used in the callback. + :param mode: the current execution mode, it's one of (`sync`, `async`). + """ + pass + + @staticmethod + def on_pod_completion(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None: + """Callback method called when the pod completes. + + :param pod: the completed pod. + :param client: the Kubernetes client that can be used in the callback. + :param mode: the current execution mode, it's one of (`sync`, `async`). 
+ """ + pass + + @staticmethod + def on_pod_cleanup(*, client: client_type, mode: str, **kwargs): + """Callback method called after cleaning/deleting the pod. + + :param client: the Kubernetes client that can be used in the callback. + :param mode: the current execution mode, it's one of (`sync`, `async`). + """ + pass + + @staticmethod + def on_operator_resuming( + *, pod: k8s.V1Pod, event: dict, client: client_type, mode: str, **kwargs + ) -> None: + """Callback method called resuming the `KubernetesPodOperator` from deferred state. + + :param pod: the current state of the pod. + :param event: the returned event from the Trigger. + :param client: the Kubernetes client that can be used in the callback. + :param mode: the current execution mode, it's one of (`sync`, `async`). + """ + pass + + @staticmethod + def progress_callback(*, line: str, client: client_type, mode: str, **kwargs) -> None: + """Callback method to process pod container logs. + + :param line: the read line of log. + :param client: the Kubernetes client that can be used in the callback. + :param mode: the current execution mode, it's one of (`sync`, `async`). 
+ """ + pass diff --git a/airflow/providers/cncf/kubernetes/operators/pod.py b/airflow/providers/cncf/kubernetes/operators/pod.py index 3cacb126d07d2..4684e6e015483 100644 --- a/airflow/providers/cncf/kubernetes/operators/pod.py +++ b/airflow/providers/cncf/kubernetes/operators/pod.py @@ -28,7 +28,7 @@ from collections.abc import Container from contextlib import AbstractContextManager from functools import cached_property -from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence +from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence, TypeVar from kubernetes.client import CoreV1Api, V1Pod, models as k8s from kubernetes.stream import stream @@ -50,6 +50,7 @@ convert_volume, convert_volume_mount, ) +from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode, KubernetesPodOperatorCallback from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook from airflow.providers.cncf.kubernetes.pod_generator import PodGenerator from airflow.providers.cncf.kubernetes.triggers.pod import KubernetesPodTrigger @@ -82,6 +83,9 @@ KUBE_CONFIG_ENV_VAR = "KUBECONFIG" +C = TypeVar("C", bound=KubernetesPodOperatorCallback) + + def _rand_str(num): """Generate random lowercase alphanumeric string of length num. @@ -248,7 +252,10 @@ class KubernetesPodOperator(BaseOperator): Default value is "File" :param active_deadline_seconds: The active_deadline_seconds which matches to active_deadline_seconds in V1PodSpec. + :param callbacks: KubernetesPodOperatorCallback instance contains the callbacks methods on different step + of KubernetesPodOperator. :param progress_callback: Callback function for receiving k8s container logs. + `progress_callback` is deprecated, please use :param `callbacks` instead. 
""" # This field can be overloaded at the instance level via base_container_name @@ -335,6 +342,7 @@ def __init__( is_delete_operator_pod: None | bool = None, termination_message_policy: str = "File", active_deadline_seconds: int | None = None, + callbacks: C = KubernetesPodOperatorCallback, progress_callback: Callable[[str], None] | None = None, **kwargs, ) -> None: @@ -439,6 +447,7 @@ def __init__( self._config_dict: dict | None = None # TODO: remove it when removing convert_config_file_to_dict self._progress_callback = progress_callback + self.callbacks = callbacks @cached_property def _incluster_namespace(self): @@ -530,7 +539,9 @@ def hook(self) -> PodOperatorHookProtocol: @cached_property def client(self) -> CoreV1Api: - return self.hook.core_v1_client + client = self.hook.core_v1_client + self.callbacks.on_sync_client_creation(client=client) + return client def find_pod(self, namespace: str, context: Context, *, exclude_checked: bool = True) -> k8s.V1Pod | None: """Return an already-running pod for this task instance if one exists.""" @@ -607,7 +618,13 @@ def execute_sync(self, context: Context): # get remote pod for use in cleanup methods self.remote_pod = self.find_pod(self.pod.metadata.namespace, context=context) + self.callbacks.on_pod_creation(pod=self.remote_pod, client=self.client, mode=ExecutionMode.SYNC) self.await_pod_start(pod=self.pod) + self.callbacks.on_pod_starting( + pod=self.find_pod(self.pod.metadata.namespace, context=context), + client=self.client, + mode=ExecutionMode.SYNC, + ) if self.get_logs: self.pod_manager.fetch_requested_container_logs( @@ -622,6 +639,12 @@ def execute_sync(self, context: Context): pod=self.pod, container_name=self.base_container_name ) + self.callbacks.on_pod_completion( + pod=self.find_pod(self.pod.metadata.namespace, context=context), + client=self.client, + mode=ExecutionMode.SYNC, + ) + if self.do_xcom_push: self.pod_manager.await_xcom_sidecar_container_start(pod=self.pod) result = 
self.extract_xcom(pod=self.pod) @@ -643,6 +666,11 @@ def execute_async(self, context: Context): pod_request_obj=self.pod_request_obj, context=context, ) + self.callbacks.on_pod_creation( + pod=self.find_pod(self.pod.metadata.namespace, context=context), + client=self.client, + mode=ExecutionMode.SYNC, + ) ti = context["ti"] ti.xcom_push(key="pod_name", value=self.pod.metadata.name) ti.xcom_push(key="pod_namespace", value=self.pod.metadata.namespace) @@ -678,6 +706,9 @@ def execute_complete(self, context: Context, event: dict, **kwargs): event["name"], event["namespace"], ) + self.callbacks.on_operator_resuming( + pod=pod, event=event, client=self.client, mode=ExecutionMode.SYNC + ) if event["status"] in ("error", "failed", "timeout"): # fetch some logs when pod is failed if self.get_logs: @@ -702,6 +733,7 @@ def execute_complete(self, context: Context, event: dict, **kwargs): pod=pod, remote_pod=pod, ) + self.callbacks.on_pod_cleanup(client=self.client, mode=ExecutionMode.SYNC) def write_logs(self, pod: k8s.V1Pod): try: @@ -726,6 +758,7 @@ def post_complete_action(self, *, pod, remote_pod, **kwargs): pod=pod, remote_pod=remote_pod, ) + self.callbacks.on_pod_cleanup(client=self.client, mode=ExecutionMode.SYNC) def cleanup(self, pod: k8s.V1Pod, remote_pod: k8s.V1Pod): istio_enabled = self.is_istio_enabled(remote_pod) diff --git a/airflow/providers/cncf/kubernetes/triggers/pod.py b/airflow/providers/cncf/kubernetes/triggers/pod.py index 5eda4242769a0..760a59f678644 100644 --- a/airflow/providers/cncf/kubernetes/triggers/pod.py +++ b/airflow/providers/cncf/kubernetes/triggers/pod.py @@ -21,9 +21,10 @@ import warnings from asyncio import CancelledError from enum import Enum -from typing import TYPE_CHECKING, Any, AsyncIterator +from typing import TYPE_CHECKING, Any, AsyncIterator, TypeVar from airflow.exceptions import AirflowProviderDeprecationWarning +from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode, KubernetesPodOperatorCallback from 
airflow.providers.cncf.kubernetes.hooks.kubernetes import AsyncKubernetesHook from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction, PodPhase from airflow.triggers.base import BaseTrigger, TriggerEvent @@ -32,6 +33,9 @@ from kubernetes_asyncio.client.models import V1Pod +C = TypeVar("C", bound=KubernetesPodOperatorCallback) + + class ContainerState(str, Enum): """ Possible container states. @@ -86,6 +90,7 @@ def __init__( startup_check_interval: int = 1, on_finish_action: str = "delete_pod", should_delete_pod: bool | None = None, + callbacks: C = KubernetesPodOperatorCallback, ): super().__init__() self.pod_name = pod_name @@ -100,6 +105,7 @@ def __init__( self.get_logs = get_logs self.startup_timeout = startup_timeout self.startup_check_interval = startup_check_interval + self.callbacks = callbacks if should_delete_pod is not None: warnings.warn( @@ -136,6 +142,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "trigger_start_time": self.trigger_start_time, "should_delete_pod": self.should_delete_pod, "on_finish_action": self.on_finish_action.value, + "callbacks": self.callbacks, }, ) @@ -143,6 +150,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] """Get current pod status and yield a TriggerEvent.""" hook = self._get_async_hook() self.log.info("Checking pod %r in namespace %r.", self.pod_name, self.pod_namespace) + _is_starting_callback_called = False try: while True: pod = await hook.get_pod( @@ -189,9 +197,21 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] self.log.info("Sleeping for %s seconds.", self.startup_check_interval) await asyncio.sleep(self.startup_check_interval) else: + if not _is_starting_callback_called: + # if the trigger fails and re-run on a different triggerer, this callback could + # be called again + self.callbacks.on_pod_starting( + pod=pod, + client=self._get_async_hook().core_v1_client, + mode=ExecutionMode.ASYNC, + ) + 
_is_starting_callback_called = True self.log.info("Sleeping for %s seconds.", self.poll_interval) await asyncio.sleep(self.poll_interval) else: + self.callbacks.on_pod_completion( + pod=pod, client=self._get_async_hook().core_v1_client, mode=ExecutionMode.ASYNC + ) yield TriggerEvent( { "name": self.pod_name, @@ -215,6 +235,11 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] name=self.pod_name, namespace=self.pod_namespace, ) + self.callbacks.on_pod_cleanup( + pod=await hook.get_pod(name=self.pod_name, namespace=self.pod_namespace), + client=self._get_async_hook().core_v1_client, + mode=ExecutionMode.ASYNC, + ) yield TriggerEvent( { "name": self.pod_name, @@ -242,6 +267,7 @@ def _get_async_hook(self) -> AsyncKubernetesHook: config_file=self.config_file, cluster_context=self.cluster_context, ) + self.callbacks.on_async_client_creation(client=self._hook.core_v1_client) return self._hook def define_container_state(self, pod: V1Pod) -> ContainerState: diff --git a/airflow/providers/cncf/kubernetes/utils/pod_manager.py b/airflow/providers/cncf/kubernetes/utils/pod_manager.py index b215d7b68abf0..3b91c0186dde6 100644 --- a/airflow/providers/cncf/kubernetes/utils/pod_manager.py +++ b/airflow/providers/cncf/kubernetes/utils/pod_manager.py @@ -38,6 +38,7 @@ from urllib3.exceptions import HTTPError as BaseHTTPError from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning +from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode, KubernetesPodOperatorCallback from airflow.providers.cncf.kubernetes.pod_generator import PodDefaults from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.timezone import utcnow @@ -277,18 +278,22 @@ class PodManager(LoggingMixin): def __init__( self, kube_client: client.CoreV1Api, + callbacks: KubernetesPodOperatorCallback = KubernetesPodOperatorCallback(), progress_callback: Callable[[str], None] | None = None, ): """ Create the launcher. 
:param kube_client: kubernetes client + :param callbacks: :param progress_callback: Callback function invoked when fetching container log. + This parameter is deprecated, please use ```` """ super().__init__() self._client = kube_client self._progress_callback = progress_callback self._watch = watch.Watch() + self._callbacks = callbacks def run_pod_async(self, pod: V1Pod, **kwargs) -> V1Pod: """Run POD asynchronously.""" @@ -419,9 +424,12 @@ def consume_logs(*, since_time: DateTime | None = None) -> DateTime | None: message_timestamp = line_timestamp progress_callback_lines.append(line) else: # previous log line is complete - if self._progress_callback: - for line in progress_callback_lines: + for line in progress_callback_lines: + if self._progress_callback: self._progress_callback(line) + self._callbacks.progress_callback( + line=line, client=self._client, mode=ExecutionMode.SYNC + ) self.log.info("[%s] %s", container_name, message_to_log) last_captured_timestamp = message_timestamp message_to_log = message @@ -432,9 +440,12 @@ def consume_logs(*, since_time: DateTime | None = None) -> DateTime | None: progress_callback_lines.append(line) finally: # log the last line and update the last_captured_timestamp - if self._progress_callback: - for line in progress_callback_lines: + for line in progress_callback_lines: + if self._progress_callback: self._progress_callback(line) + self._callbacks.progress_callback( + line=line, client=self._client, mode=ExecutionMode.SYNC + ) self.log.info("[%s] %s", container_name, message_to_log) last_captured_timestamp = message_timestamp except BaseHTTPError: diff --git a/airflow/providers/google/cloud/triggers/kubernetes_engine.py b/airflow/providers/google/cloud/triggers/kubernetes_engine.py index 1fbaef72a9ce0..6bd08e61686cc 100644 --- a/airflow/providers/google/cloud/triggers/kubernetes_engine.py +++ b/airflow/providers/google/cloud/triggers/kubernetes_engine.py @@ -19,11 +19,12 @@ import asyncio import warnings -from typing 
import TYPE_CHECKING, Any, AsyncIterator, Sequence +from typing import TYPE_CHECKING, Any, AsyncIterator, Sequence, TypeVar from google.cloud.container_v1.types import Operation from airflow.exceptions import AirflowProviderDeprecationWarning +from airflow.providers.cncf.kubernetes.callbacks import KubernetesPodOperatorCallback from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction try: @@ -37,6 +38,8 @@ if TYPE_CHECKING: from datetime import datetime +C = TypeVar("C", bound=KubernetesPodOperatorCallback) + class GKEStartPodTrigger(KubernetesPodTrigger): """ @@ -80,6 +83,7 @@ def __init__( startup_timeout: int = 120, on_finish_action: str = "delete_pod", should_delete_pod: bool | None = None, + callbacks: C = KubernetesPodOperatorCallback, *args, **kwargs, ): @@ -100,6 +104,7 @@ def __init__( self.in_cluster = in_cluster self.get_logs = get_logs self.startup_timeout = startup_timeout + self.callbacks = callbacks if should_delete_pod is not None: warnings.warn( @@ -134,6 +139,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "base_container_name": self.base_container_name, "should_delete_pod": self.should_delete_pod, "on_finish_action": self.on_finish_action.value, + "callbacks": self.callbacks, }, ) diff --git a/tests/providers/cncf/kubernetes/operators/test_pod.py b/tests/providers/cncf/kubernetes/operators/test_pod.py index 5e9cbbb9168a0..d5f566112e4fe 100644 --- a/tests/providers/cncf/kubernetes/operators/test_pod.py +++ b/tests/providers/cncf/kubernetes/operators/test_pod.py @@ -1487,9 +1487,13 @@ def run_pod_async(self, operator: KubernetesPodOperator, map_index: int = -1): return remote_pod_mock @pytest.mark.parametrize("do_xcom_push", [True, False]) + @patch(KUB_OP_PATH.format("client")) + @patch(KUB_OP_PATH.format("find_pod")) @patch(KUB_OP_PATH.format("build_pod_request_obj")) @patch(KUB_OP_PATH.format("get_or_create_pod")) - def test_async_create_pod_should_execute_successfully(self, mocked_pod, mocked_pod_obj, do_xcom_push): 
+ def test_async_create_pod_should_execute_successfully( + self, mocked_pod, mocked_pod_obj, mocked_found_pod, mocked_client, do_xcom_push + ): """ Asserts that a task is deferred and the KubernetesCreatePodTrigger will be fired when the KubernetesPodOperator is executed in deferrable mode when deferrable=True. @@ -1517,7 +1521,7 @@ def test_async_create_pod_should_execute_successfully(self, mocked_pod, mocked_p mocked_pod.return_value.metadata.namespace = TEST_NAMESPACE context = create_context(k) - ti_mock = MagicMock() + ti_mock = MagicMock(**{"map_index": -1}) context["ti"] = ti_mock with pytest.raises(TaskDeferred) as exc: diff --git a/tests/providers/cncf/kubernetes/triggers/test_pod.py b/tests/providers/cncf/kubernetes/triggers/test_pod.py index 5694698d10cea..ddae987f41597 100644 --- a/tests/providers/cncf/kubernetes/triggers/test_pod.py +++ b/tests/providers/cncf/kubernetes/triggers/test_pod.py @@ -26,6 +26,7 @@ import pytest from kubernetes.client import models as k8s +from airflow.providers.cncf.kubernetes.callbacks import KubernetesPodOperatorCallback from airflow.providers.cncf.kubernetes.triggers.pod import ContainerState, KubernetesPodTrigger from airflow.providers.cncf.kubernetes.utils.pod_manager import PodPhase from airflow.triggers.base import TriggerEvent @@ -47,6 +48,10 @@ ON_FINISH_ACTION = "delete_pod" +async def async_mock(): + return mock.MagicMock() + + @pytest.fixture def trigger(): return KubernetesPodTrigger( @@ -90,6 +95,7 @@ def test_serialize(self, trigger): "trigger_start_time": TRIGGER_START_TIME, "on_finish_action": ON_FINISH_ACTION, "should_delete_pod": ON_FINISH_ACTION == "delete_pod", + "callbacks": KubernetesPodOperatorCallback, } @pytest.mark.asyncio @@ -217,7 +223,7 @@ async def test_logging_in_trigger_when_cancelled_should_execute_successfully_and Test that KubernetesPodTrigger fires the correct event in case if the task was cancelled. 
""" - mock_hook.return_value.get_pod.side_effect = CancelledError() + mock_hook.return_value.get_pod.side_effect = [CancelledError(), async_mock()] mock_hook.return_value.read_logs.return_value = self._mock_pod_result(mock.MagicMock()) mock_hook.return_value.delete_pod.return_value = self._mock_pod_result(mock.MagicMock()) @@ -263,7 +269,7 @@ async def test_logging_in_trigger_when_cancelled_should_execute_successfully_wit Test that KubernetesPodTrigger fires the correct event if the task was cancelled. """ - mock_hook.return_value.get_pod.side_effect = CancelledError() + mock_hook.return_value.get_pod.side_effect = [CancelledError(), async_mock()] mock_hook.return_value.read_logs.return_value = self._mock_pod_result(mock.MagicMock()) mock_hook.return_value.delete_pod.return_value = self._mock_pod_result(mock.MagicMock()) diff --git a/tests/providers/google/cloud/operators/test_kubernetes_engine.py b/tests/providers/google/cloud/operators/test_kubernetes_engine.py index 5f91fd9e4ea99..ba574a1ea6800 100644 --- a/tests/providers/google/cloud/operators/test_kubernetes_engine.py +++ b/tests/providers/google/cloud/operators/test_kubernetes_engine.py @@ -440,6 +440,8 @@ def setup_method(self): self.gke_op._ssl_ca_cert = SSL_CA_CERT @mock.patch.dict(os.environ, {}) + @mock.patch(KUB_OP_PATH.format("client")) + @mock.patch(KUB_OP_PATH.format("find_pod")) @mock.patch(KUB_OP_PATH.format("build_pod_request_obj")) @mock.patch(KUB_OP_PATH.format("get_or_create_pod")) @mock.patch( @@ -448,7 +450,13 @@ def setup_method(self): ) @mock.patch(f"{GKE_OP_PATH}.fetch_cluster_info") def test_async_create_pod_should_execute_successfully( - self, fetch_cluster_info_mock, get_con_mock, mocked_pod, mocked_pod_obj + self, + fetch_cluster_info_mock, + get_con_mock, + mocked_pod, + mocked_pod_obj, + mocked_found_pod, + mocked_client, ): """ Asserts that a task is deferred and the GKEStartPodTrigger will be fired diff --git a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py 
b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py index 41f0a67211fc2..943019108f7c9 100644 --- a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py +++ b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py @@ -231,7 +231,11 @@ async def test_logging_in_trigger_when_cancelled_should_execute_successfully( """ Test that GKEStartPodTrigger fires the correct event in case if the task was cancelled. """ - mock_hook.return_value.get_pod.side_effect = CancelledError() + + async def async_mock(): + return mock.MagicMock() + + mock_hook.return_value.get_pod.side_effect = [CancelledError(), async_mock()] mock_hook.return_value.read_logs.return_value = self._mock_pod_result(mock.MagicMock()) mock_hook.return_value.delete_pod.return_value = self._mock_pod_result(mock.MagicMock()) From 2797aefd773d835d0140b936f1278f6114448c12 Mon Sep 17 00:00:00 2001 From: hussein-awala Date: Tue, 28 Nov 2023 19:38:18 +0200 Subject: [PATCH 02/15] Trigger tests with old-style union --- airflow/providers/cncf/kubernetes/callbacks.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/airflow/providers/cncf/kubernetes/callbacks.py b/airflow/providers/cncf/kubernetes/callbacks.py index 881e585d617f1..76baf414a816c 100644 --- a/airflow/providers/cncf/kubernetes/callbacks.py +++ b/airflow/providers/cncf/kubernetes/callbacks.py @@ -17,11 +17,12 @@ from __future__ import annotations from enum import Enum +from typing import Union import kubernetes.client as k8s import kubernetes_asyncio.client as async_k8s -client_type = k8s.CoreV1Api | async_k8s.CoreV1Api +client_type = Union[k8s.CoreV1Api, async_k8s.CoreV1Api] class ExecutionMode(str, Enum): From 6b884f09ce03522488e0432d62e0d02178f222fe Mon Sep 17 00:00:00 2001 From: hussein-awala Date: Tue, 28 Nov 2023 20:26:49 +0200 Subject: [PATCH 03/15] Fix GCP K8S test --- tests/providers/google/cloud/triggers/test_kubernetes_engine.py | 2 ++ 1 file changed, 2 insertions(+) diff --git 
a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py index 943019108f7c9..eeb69559ff9ef 100644 --- a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py +++ b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py @@ -27,6 +27,7 @@ from google.cloud.container_v1.types import Operation from kubernetes.client import models as k8s +from airflow.providers.cncf.kubernetes.callbacks import KubernetesPodOperatorCallback from airflow.providers.cncf.kubernetes.triggers.kubernetes_pod import ContainerState from airflow.providers.google.cloud.triggers.kubernetes_engine import GKEOperationTrigger, GKEStartPodTrigger from airflow.triggers.base import TriggerEvent @@ -101,6 +102,7 @@ def test_serialize_should_execute_successfully(self, trigger): "base_container_name": BASE_CONTAINER_NAME, "on_finish_action": ON_FINISH_ACTION, "should_delete_pod": SHOULD_DELETE_POD, + "callbacks": KubernetesPodOperatorCallback, } @pytest.mark.asyncio From 3ddfd82e02800146b0f56b7885c0bc363603bd6c Mon Sep 17 00:00:00 2001 From: hussein-awala Date: Tue, 28 Nov 2023 21:36:13 +0200 Subject: [PATCH 04/15] Fix callback param type and cleanup calls --- airflow/providers/cncf/kubernetes/callbacks.py | 3 ++- airflow/providers/cncf/kubernetes/operators/pod.py | 9 +++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/airflow/providers/cncf/kubernetes/callbacks.py b/airflow/providers/cncf/kubernetes/callbacks.py index 76baf414a816c..d32c1e29f4dae 100644 --- a/airflow/providers/cncf/kubernetes/callbacks.py +++ b/airflow/providers/cncf/kubernetes/callbacks.py @@ -82,9 +82,10 @@ def on_pod_completion(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwarg pass @staticmethod - def on_pod_cleanup(*, client: client_type, mode: str, **kwargs): + def on_pod_cleanup(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs): """Callback method called after cleaning/deleting the pod. 
+ :param pod: the completed pod. :param client: the Kubernetes client that can be used in the callback. :param mode: the current execution mode, it's one of (`sync`, `async`). """ diff --git a/airflow/providers/cncf/kubernetes/operators/pod.py b/airflow/providers/cncf/kubernetes/operators/pod.py index 4684e6e015483..c4b651b189358 100644 --- a/airflow/providers/cncf/kubernetes/operators/pod.py +++ b/airflow/providers/cncf/kubernetes/operators/pod.py @@ -83,7 +83,7 @@ KUBE_CONFIG_ENV_VAR = "KUBECONFIG" -C = TypeVar("C", bound=KubernetesPodOperatorCallback) +C = TypeVar("C", bound=type[KubernetesPodOperatorCallback]) def _rand_str(num): @@ -653,10 +653,12 @@ def execute_sync(self, context: Context): self.pod, istio_enabled, self.base_container_name ) finally: + pod_to_clean = self.pod or self.pod_request_obj self.cleanup( - pod=self.pod or self.pod_request_obj, + pod=pod_to_clean, remote_pod=self.remote_pod, ) + self.callbacks.on_pod_cleanup(pod=pod_to_clean, client=self.client, mode=ExecutionMode.SYNC) if self.do_xcom_push: return result @@ -733,7 +735,6 @@ def execute_complete(self, context: Context, event: dict, **kwargs): pod=pod, remote_pod=pod, ) - self.callbacks.on_pod_cleanup(client=self.client, mode=ExecutionMode.SYNC) def write_logs(self, pod: k8s.V1Pod): try: @@ -758,7 +759,7 @@ def post_complete_action(self, *, pod, remote_pod, **kwargs): pod=pod, remote_pod=remote_pod, ) - self.callbacks.on_pod_cleanup(client=self.client, mode=ExecutionMode.SYNC) + self.callbacks.on_pod_cleanup(pod=pod, client=self.client, mode=ExecutionMode.SYNC) def cleanup(self, pod: k8s.V1Pod, remote_pod: k8s.V1Pod): istio_enabled = self.is_istio_enabled(remote_pod) From d2452984f71d8e3399366f6ef76cc96f6d0a52b2 Mon Sep 17 00:00:00 2001 From: hussein-awala Date: Tue, 28 Nov 2023 23:05:36 +0200 Subject: [PATCH 05/15] Some fixes and add unit tests --- .../cncf/kubernetes/operators/pod.py | 4 +- .../providers/cncf/kubernetes/triggers/pod.py | 17 ++- 
.../cncf/kubernetes/utils/pod_manager.py | 4 +- .../cncf/kubernetes/operators/test_pod.py | 113 ++++++++++++++++++ .../cncf/kubernetes/test_callbacks.py | 65 ++++++++++ .../cncf/kubernetes/triggers/test_pod.py | 52 ++++++++ .../cncf/kubernetes/utils/test_pod_manager.py | 28 ++++- 7 files changed, 278 insertions(+), 5 deletions(-) create mode 100644 tests/providers/cncf/kubernetes/test_callbacks.py diff --git a/airflow/providers/cncf/kubernetes/operators/pod.py b/airflow/providers/cncf/kubernetes/operators/pod.py index c4b651b189358..4c92b4dfd74b9 100644 --- a/airflow/providers/cncf/kubernetes/operators/pod.py +++ b/airflow/providers/cncf/kubernetes/operators/pod.py @@ -525,7 +525,9 @@ def _get_ti_pod_labels(context: Context | None = None, include_try_number: bool @cached_property def pod_manager(self) -> PodManager: - return PodManager(kube_client=self.client, progress_callback=self._progress_callback) + return PodManager( + kube_client=self.client, callbacks=self.callbacks, progress_callback=self._progress_callback + ) @cached_property def hook(self) -> PodOperatorHookProtocol: diff --git a/airflow/providers/cncf/kubernetes/triggers/pod.py b/airflow/providers/cncf/kubernetes/triggers/pod.py index 760a59f678644..e8d6853b088a5 100644 --- a/airflow/providers/cncf/kubernetes/triggers/pod.py +++ b/airflow/providers/cncf/kubernetes/triggers/pod.py @@ -33,7 +33,7 @@ from kubernetes_asyncio.client.models import V1Pod -C = TypeVar("C", bound=KubernetesPodOperatorCallback) +C = TypeVar("C", bound=type[KubernetesPodOperatorCallback]) class ContainerState(str, Enum): @@ -165,6 +165,15 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] self.log.debug("Container %s status: %s", self.base_container_name, container_state) if container_state == ContainerState.TERMINATED: + if not _is_starting_callback_called: + self.callbacks.on_pod_starting( + pod=pod, + client=self._get_async_hook().core_v1_client, + mode=ExecutionMode.ASYNC, + ) + 
self.callbacks.on_pod_completion( + pod=pod, client=self._get_async_hook().core_v1_client, mode=ExecutionMode.ASYNC + ) yield TriggerEvent( { "name": self.pod_name, @@ -209,6 +218,12 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] self.log.info("Sleeping for %s seconds.", self.poll_interval) await asyncio.sleep(self.poll_interval) else: + if not _is_starting_callback_called: + self.callbacks.on_pod_starting( + pod=pod, + client=self._get_async_hook().core_v1_client, + mode=ExecutionMode.ASYNC, + ) self.callbacks.on_pod_completion( pod=pod, client=self._get_async_hook().core_v1_client, mode=ExecutionMode.ASYNC ) diff --git a/airflow/providers/cncf/kubernetes/utils/pod_manager.py b/airflow/providers/cncf/kubernetes/utils/pod_manager.py index 2b0cbf0450c46..aa6d28aa404da 100644 --- a/airflow/providers/cncf/kubernetes/utils/pod_manager.py +++ b/airflow/providers/cncf/kubernetes/utils/pod_manager.py @@ -51,6 +51,8 @@ from kubernetes.client.models.v1_pod import V1Pod from urllib3.response import HTTPResponse + from airflow.providers.cncf.kubernetes.operators.pod import C + EMPTY_XCOM_RESULT = "__airflow_xcom_result_empty__" """ Sentinel for no xcom result. 
@@ -288,7 +290,7 @@ class PodManager(LoggingMixin): def __init__( self, kube_client: client.CoreV1Api, - callbacks: KubernetesPodOperatorCallback = KubernetesPodOperatorCallback(), + callbacks: C = KubernetesPodOperatorCallback, progress_callback: Callable[[str], None] | None = None, ): """ diff --git a/tests/providers/cncf/kubernetes/operators/test_pod.py b/tests/providers/cncf/kubernetes/operators/test_pod.py index d5f566112e4fe..0044fc9453a9b 100644 --- a/tests/providers/cncf/kubernetes/operators/test_pod.py +++ b/tests/providers/cncf/kubernetes/operators/test_pod.py @@ -1379,6 +1379,119 @@ def test_get_logs_but_not_for_base_container( # check that we wait for the xcom sidecar to start before extracting XCom mock_await_xcom_sidecar.assert_called_once_with(pod=pod) + @patch(HOOK_CLASS, new=MagicMock) + @patch(KUB_OP_PATH.format("find_pod")) + def test_execute_sync_callbacks(self, find_pod_mock): + from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode + + from ..test_callbacks import MockKubernetesPodOperatorCallback, MockWrapper + + MockWrapper.reset() + mock_callbacks = MockWrapper.mock_callbacks + found_pods = [MagicMock(), MagicMock(), MagicMock()] + find_pod_mock.side_effect = [None] + found_pods + + remote_pod_mock = MagicMock() + remote_pod_mock.status.phase = "Succeeded" + self.await_pod_mock.return_value = remote_pod_mock + k = KubernetesPodOperator( + namespace="default", + image="ubuntu:16.04", + cmds=["bash", "-cx"], + arguments=["echo 10"], + labels={"foo": "bar"}, + name="test", + task_id="task", + do_xcom_push=False, + callbacks=MockKubernetesPodOperatorCallback, + ) + self.run_pod(k) + + # check on_sync_client_creation callback + mock_callbacks.on_sync_client_creation.assert_called_once() + assert mock_callbacks.on_sync_client_creation.call_args.kwargs == {"client": k.client} + + # check on_pod_creation callback + mock_callbacks.on_pod_creation.assert_called_once() + assert mock_callbacks.on_pod_creation.call_args.kwargs == { + 
"client": k.client, + "mode": ExecutionMode.SYNC, + "pod": found_pods[0], + } + + # check on_pod_starting callback + mock_callbacks.on_pod_starting.assert_called_once() + assert mock_callbacks.on_pod_starting.call_args.kwargs == { + "client": k.client, + "mode": ExecutionMode.SYNC, + "pod": found_pods[1], + } + + # check on_pod_completion callback + mock_callbacks.on_pod_completion.assert_called_once() + assert mock_callbacks.on_pod_completion.call_args.kwargs == { + "client": k.client, + "mode": ExecutionMode.SYNC, + "pod": found_pods[2], + } + + # check on_pod_cleanup callback + mock_callbacks.on_pod_cleanup.assert_called_once() + assert mock_callbacks.on_pod_cleanup.call_args.kwargs == { + "client": k.client, + "mode": ExecutionMode.SYNC, + "pod": k.pod, + } + + @patch(HOOK_CLASS, new=MagicMock) + def test_execute_async_callbacks(self): + from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode + + from ..test_callbacks import MockKubernetesPodOperatorCallback, MockWrapper + + MockWrapper.reset() + mock_callbacks = MockWrapper.mock_callbacks + remote_pod_mock = MagicMock() + remote_pod_mock.status.phase = "Succeeded" + self.await_pod_mock.return_value = remote_pod_mock + + k = KubernetesPodOperator( + namespace="default", + image="ubuntu:16.04", + cmds=["bash", "-cx"], + arguments=["echo 10"], + labels={"foo": "bar"}, + name="test", + task_id="task", + do_xcom_push=False, + callbacks=MockKubernetesPodOperatorCallback, + ) + k.execute_complete( + context=create_context(k), + event={ + "status": "success", + "message": TEST_SUCCESS_MESSAGE, + "name": TEST_NAME, + "namespace": TEST_NAMESPACE, + }, + ) + + # check on_operator_resuming callback + mock_callbacks.on_pod_cleanup.assert_called_once() + assert mock_callbacks.on_pod_cleanup.call_args.kwargs == { + "client": k.client, + "mode": ExecutionMode.SYNC, + "pod": remote_pod_mock, + } + + # check on_pod_cleanup callback + mock_callbacks.on_pod_cleanup.assert_called_once() + assert 
mock_callbacks.on_pod_cleanup.call_args.kwargs == { + "client": k.client, + "mode": ExecutionMode.SYNC, + "pod": remote_pod_mock, + } + class TestSuppress: def test__suppress(self, caplog): diff --git a/tests/providers/cncf/kubernetes/test_callbacks.py b/tests/providers/cncf/kubernetes/test_callbacks.py new file mode 100644 index 0000000000000..2757b8296a7d0 --- /dev/null +++ b/tests/providers/cncf/kubernetes/test_callbacks.py @@ -0,0 +1,65 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from unittest.mock import MagicMock + +from airflow.providers.cncf.kubernetes.callbacks import KubernetesPodOperatorCallback + + +class MockWrapper: + mock_callbacks = MagicMock() + + @classmethod + def reset(cls): + cls.mock_callbacks.reset_mock() + + +class MockKubernetesPodOperatorCallback(KubernetesPodOperatorCallback): + """`KubernetesPodOperator` callbacks methods.""" + + @staticmethod + def on_sync_client_creation(*args, **kwargs) -> None: + MockWrapper.mock_callbacks.on_sync_client_creation(*args, **kwargs) + + @staticmethod + def on_async_client_creation(*args, **kwargs) -> None: + MockWrapper.mock_callbacks.on_async_client_creation(*args, **kwargs) + + @staticmethod + def on_pod_creation(*args, **kwargs) -> None: + MockWrapper.mock_callbacks.on_pod_creation(*args, **kwargs) + + @staticmethod + def on_pod_starting(*args, **kwargs) -> None: + MockWrapper.mock_callbacks.on_pod_starting(*args, **kwargs) + + @staticmethod + def on_pod_completion(*args, **kwargs) -> None: + MockWrapper.mock_callbacks.on_pod_completion(*args, **kwargs) + + @staticmethod + def on_pod_cleanup(*args, **kwargs) -> None: + MockWrapper.mock_callbacks.on_pod_cleanup(*args, **kwargs) + + @staticmethod + def on_operator_resuming(*args, **kwargs) -> None: + MockWrapper.mock_callbacks.on_operator_resuming(*args, **kwargs) + + @staticmethod + def progress_callback(*args, **kwargs) -> None: + MockWrapper.mock_callbacks.progress_callback(*args, **kwargs) diff --git a/tests/providers/cncf/kubernetes/triggers/test_pod.py b/tests/providers/cncf/kubernetes/triggers/test_pod.py index ddae987f41597..dd9697a5463f0 100644 --- a/tests/providers/cncf/kubernetes/triggers/test_pod.py +++ b/tests/providers/cncf/kubernetes/triggers/test_pod.py @@ -375,3 +375,55 @@ async def test_run_loop_return_timeout_event( ) == actual ) + + @pytest.mark.asyncio + @mock.patch(f"{TRIGGER_PATH}.define_container_state") + @mock.patch(f"{TRIGGER_PATH}._get_async_hook") + async def 
test_callbacks(self, mock_hook, mock_method): + from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode + + from ..test_callbacks import MockKubernetesPodOperatorCallback, MockWrapper + + MockWrapper.reset() + mock_callbacks = MockWrapper.mock_callbacks + + pods_mock = [ + self._mock_pod_result(mock.MagicMock(status=mock.MagicMock(phase=PodPhase.PENDING))), + self._mock_pod_result(mock.MagicMock(status=mock.MagicMock(phase=PodPhase.RUNNING))), + self._mock_pod_result(mock.MagicMock(status=mock.MagicMock(phase=PodPhase.SUCCEEDED))), + ] + mock_hook.return_value.get_pod.side_effect = pods_mock + mock_method.side_effect = [ContainerState.WAITING, ContainerState.RUNNING, ContainerState.TERMINATED] + + k = KubernetesPodTrigger( + pod_name=POD_NAME, + pod_namespace=NAMESPACE, + base_container_name=BASE_CONTAINER_NAME, + kubernetes_conn_id=CONN_ID, + poll_interval=POLL_INTERVAL, + cluster_context=CLUSTER_CONTEXT, + config_file=CONFIG_FILE, + in_cluster=IN_CLUSTER, + get_logs=GET_LOGS, + startup_timeout=STARTUP_TIMEOUT_SECS, + trigger_start_time=TRIGGER_START_TIME, + on_finish_action=ON_FINISH_ACTION, + callbacks=MockKubernetesPodOperatorCallback, + ) + await k.run().asend(None) + + # check on_pod_starting callback + mock_callbacks.on_pod_starting.assert_called_once() + assert mock_callbacks.on_pod_starting.call_args.kwargs == { + "client": k._get_async_hook().core_v1_client, + "mode": ExecutionMode.ASYNC, + "pod": pods_mock[1].result(), + } + + # check on_pod_completion callback + mock_callbacks.on_pod_completion.assert_called_once() + assert mock_callbacks.on_pod_completion.call_args.kwargs == { + "client": k._get_async_hook().core_v1_client, + "mode": ExecutionMode.ASYNC, + "pod": pods_mock[2].result(), + } diff --git a/tests/providers/cncf/kubernetes/utils/test_pod_manager.py b/tests/providers/cncf/kubernetes/utils/test_pod_manager.py index a4232da30eb3c..abd68eae164fa 100644 --- a/tests/providers/cncf/kubernetes/utils/test_pod_manager.py +++ 
b/tests/providers/cncf/kubernetes/utils/test_pod_manager.py @@ -43,13 +43,18 @@ ) from airflow.utils.timezone import utc +from ..test_callbacks import MockKubernetesPodOperatorCallback, MockWrapper + class TestPodManager: def setup_method(self): self.mock_progress_callback = mock.Mock() self.mock_kube_client = mock.Mock() + self.mock_callbacks = MockWrapper.mock_callbacks self.pod_manager = PodManager( - kube_client=self.mock_kube_client, progress_callback=self.mock_progress_callback + kube_client=self.mock_kube_client, + callbacks=MockKubernetesPodOperatorCallback, + progress_callback=self.mock_progress_callback, ) def test_read_pod_logs_successfully_returns_logs(self): @@ -273,7 +278,7 @@ def test_fetch_container_logs_returning_last_timestamp( @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.container_is_running") @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.read_pod_logs") - def test_fetch_container_logs_invoke_progress_callback( + def test_fetch_container_logs_invoke_deprecated_progress_callback( self, mock_read_pod_logs, mock_container_is_running ): message = "2020-10-08T14:16:17.793417674Z message" @@ -284,6 +289,24 @@ def test_fetch_container_logs_invoke_progress_callback( self.pod_manager.fetch_container_logs(mock.MagicMock(), mock.MagicMock(), follow=True) self.mock_progress_callback.assert_has_calls([mock.call(message), mock.call(no_ts_message)]) + @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.container_is_running") + @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.read_pod_logs") + def test_fetch_container_logs_invoke_progress_callback( + self, mock_read_pod_logs, mock_container_is_running + ): + message = "2020-10-08T14:16:17.793417674Z message" + no_ts_message = "notimestamp" + mock_read_pod_logs.return_value = [bytes(message, "utf-8"), bytes(no_ts_message, "utf-8")] + mock_container_is_running.return_value = False + + 
self.pod_manager.fetch_container_logs(mock.MagicMock(), mock.MagicMock(), follow=True) + self.mock_callbacks.progress_callback.assert_has_calls( + [ + mock.call(line=message, client=self.pod_manager._client, mode="sync"), + mock.call(line=no_ts_message, client=self.pod_manager._client, mode="sync"), + ] + ) + @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.container_is_running") def test_fetch_container_logs_failures(self, mock_container_is_running): last_timestamp_string = "2020-10-08T14:18:17.793417674Z" @@ -308,6 +331,7 @@ def consumer_iter(): status = self.pod_manager.fetch_container_logs(mock.MagicMock(), mock.MagicMock(), follow=True) assert status.last_log_time == cast(DateTime, pendulum.parse(last_timestamp_string)) assert self.mock_progress_callback.call_count == expected_call_count + assert self.mock_callbacks.progress_callback.call_count == expected_call_count @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.container_is_running") @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.read_pod_logs") From f722e40f43f04c7e6b2798a7a784c605be98c9dc Mon Sep 17 00:00:00 2001 From: hussein-awala Date: Tue, 28 Nov 2023 23:43:14 +0200 Subject: [PATCH 06/15] Replace type by Type --- airflow/providers/cncf/kubernetes/operators/pod.py | 4 ++-- airflow/providers/cncf/kubernetes/triggers/pod.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/airflow/providers/cncf/kubernetes/operators/pod.py b/airflow/providers/cncf/kubernetes/operators/pod.py index 4c92b4dfd74b9..60cabdc29dd27 100644 --- a/airflow/providers/cncf/kubernetes/operators/pod.py +++ b/airflow/providers/cncf/kubernetes/operators/pod.py @@ -28,7 +28,7 @@ from collections.abc import Container from contextlib import AbstractContextManager from functools import cached_property -from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Iterable, 
Sequence, Type, TypeVar from kubernetes.client import CoreV1Api, V1Pod, models as k8s from kubernetes.stream import stream @@ -83,7 +83,7 @@ KUBE_CONFIG_ENV_VAR = "KUBECONFIG" -C = TypeVar("C", bound=type[KubernetesPodOperatorCallback]) +C = TypeVar("C", bound=Type[KubernetesPodOperatorCallback]) def _rand_str(num): diff --git a/airflow/providers/cncf/kubernetes/triggers/pod.py b/airflow/providers/cncf/kubernetes/triggers/pod.py index e8d6853b088a5..32aedaac377a6 100644 --- a/airflow/providers/cncf/kubernetes/triggers/pod.py +++ b/airflow/providers/cncf/kubernetes/triggers/pod.py @@ -21,7 +21,7 @@ import warnings from asyncio import CancelledError from enum import Enum -from typing import TYPE_CHECKING, Any, AsyncIterator, TypeVar +from typing import TYPE_CHECKING, Any, AsyncIterator, Type, TypeVar from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode, KubernetesPodOperatorCallback @@ -33,7 +33,7 @@ from kubernetes_asyncio.client.models import V1Pod -C = TypeVar("C", bound=type[KubernetesPodOperatorCallback]) +C = TypeVar("C", bound=Type[KubernetesPodOperatorCallback]) class ContainerState(str, Enum): From c4ae081a5f469a975e7036014cc6ab43f35748bd Mon Sep 17 00:00:00 2001 From: hussein-awala Date: Tue, 28 Nov 2023 23:55:20 +0200 Subject: [PATCH 07/15] Reset mock_callbacks in pod manager tests --- .../providers/cncf/kubernetes/utils/test_pod_manager.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/providers/cncf/kubernetes/utils/test_pod_manager.py b/tests/providers/cncf/kubernetes/utils/test_pod_manager.py index abd68eae164fa..e8dca262e0b24 100644 --- a/tests/providers/cncf/kubernetes/utils/test_pod_manager.py +++ b/tests/providers/cncf/kubernetes/utils/test_pod_manager.py @@ -50,7 +50,6 @@ class TestPodManager: def setup_method(self): self.mock_progress_callback = mock.Mock() self.mock_kube_client = mock.Mock() - self.mock_callbacks = 
MockWrapper.mock_callbacks self.pod_manager = PodManager( kube_client=self.mock_kube_client, callbacks=MockKubernetesPodOperatorCallback, @@ -294,13 +293,15 @@ def test_fetch_container_logs_invoke_deprecated_progress_callback( def test_fetch_container_logs_invoke_progress_callback( self, mock_read_pod_logs, mock_container_is_running ): + MockWrapper.reset() + mock_callbacks = MockWrapper.mock_callbacks message = "2020-10-08T14:16:17.793417674Z message" no_ts_message = "notimestamp" mock_read_pod_logs.return_value = [bytes(message, "utf-8"), bytes(no_ts_message, "utf-8")] mock_container_is_running.return_value = False self.pod_manager.fetch_container_logs(mock.MagicMock(), mock.MagicMock(), follow=True) - self.mock_callbacks.progress_callback.assert_has_calls( + mock_callbacks.progress_callback.assert_has_calls( [ mock.call(line=message, client=self.pod_manager._client, mode="sync"), mock.call(line=no_ts_message, client=self.pod_manager._client, mode="sync"), @@ -309,6 +310,8 @@ def test_fetch_container_logs_invoke_progress_callback( @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.container_is_running") def test_fetch_container_logs_failures(self, mock_container_is_running): + MockWrapper.reset() + mock_callbacks = MockWrapper.mock_callbacks last_timestamp_string = "2020-10-08T14:18:17.793417674Z" messages = [ bytes("2020-10-08T14:16:17.793417674Z message", "utf-8"), @@ -331,7 +334,7 @@ def consumer_iter(): status = self.pod_manager.fetch_container_logs(mock.MagicMock(), mock.MagicMock(), follow=True) assert status.last_log_time == cast(DateTime, pendulum.parse(last_timestamp_string)) assert self.mock_progress_callback.call_count == expected_call_count - assert self.mock_callbacks.progress_callback.call_count == expected_call_count + assert mock_callbacks.progress_callback.call_count == expected_call_count @mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.container_is_running") 
@mock.patch("airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager.read_pod_logs") From 6173c1b237dc366d955d342038d720b9348ffd1c Mon Sep 17 00:00:00 2001 From: hussein-awala Date: Wed, 10 Jan 2024 00:42:48 +0100 Subject: [PATCH 08/15] Fix static checks --- airflow/providers/cncf/kubernetes/operators/pod.py | 6 ++---- airflow/providers/cncf/kubernetes/triggers/pod.py | 7 ++----- airflow/providers/cncf/kubernetes/utils/pod_manager.py | 3 +-- .../providers/google/cloud/triggers/kubernetes_engine.py | 6 ++---- 4 files changed, 7 insertions(+), 15 deletions(-) diff --git a/airflow/providers/cncf/kubernetes/operators/pod.py b/airflow/providers/cncf/kubernetes/operators/pod.py index 3e2e176f804f8..838ddb813735e 100644 --- a/airflow/providers/cncf/kubernetes/operators/pod.py +++ b/airflow/providers/cncf/kubernetes/operators/pod.py @@ -27,7 +27,7 @@ from collections.abc import Container from contextlib import AbstractContextManager from functools import cached_property -from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence, Type, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence from kubernetes.client import CoreV1Api, V1Pod, models as k8s from kubernetes.stream import stream @@ -85,8 +85,6 @@ KUBE_CONFIG_ENV_VAR = "KUBECONFIG" -C = TypeVar("C", bound=Type[KubernetesPodOperatorCallback]) - class PodReattachFailure(AirflowException): """When we expect to be able to find a pod but cannot.""" @@ -295,7 +293,7 @@ def __init__( is_delete_operator_pod: None | bool = None, termination_message_policy: str = "File", active_deadline_seconds: int | None = None, - callbacks: C = KubernetesPodOperatorCallback, + callbacks: type[KubernetesPodOperatorCallback] = KubernetesPodOperatorCallback, progress_callback: Callable[[str], None] | None = None, **kwargs, ) -> None: diff --git a/airflow/providers/cncf/kubernetes/triggers/pod.py b/airflow/providers/cncf/kubernetes/triggers/pod.py index e82400f683901..aee7c7b72c879 100644 --- 
a/airflow/providers/cncf/kubernetes/triggers/pod.py +++ b/airflow/providers/cncf/kubernetes/triggers/pod.py @@ -23,7 +23,7 @@ from asyncio import CancelledError from enum import Enum from functools import cached_property -from typing import TYPE_CHECKING, Any, AsyncIterator, Type, TypeVar +from typing import TYPE_CHECKING, Any, AsyncIterator from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode, KubernetesPodOperatorCallback @@ -35,9 +35,6 @@ from kubernetes_asyncio.client.models import V1Pod -C = TypeVar("C", bound=Type[KubernetesPodOperatorCallback]) - - class ContainerState(str, Enum): """ Possible container states. @@ -92,7 +89,7 @@ def __init__( startup_check_interval: int = 1, on_finish_action: str = "delete_pod", should_delete_pod: bool | None = None, - callbacks: C = KubernetesPodOperatorCallback, + callbacks: type[KubernetesPodOperatorCallback] = KubernetesPodOperatorCallback, ): super().__init__() self.pod_name = pod_name diff --git a/airflow/providers/cncf/kubernetes/utils/pod_manager.py b/airflow/providers/cncf/kubernetes/utils/pod_manager.py index f67fea3f66398..d08677fad1b28 100644 --- a/airflow/providers/cncf/kubernetes/utils/pod_manager.py +++ b/airflow/providers/cncf/kubernetes/utils/pod_manager.py @@ -51,7 +51,6 @@ from kubernetes.client.models.v1_pod import V1Pod from urllib3.response import HTTPResponse - from airflow.providers.cncf.kubernetes.operators.pod import C EMPTY_XCOM_RESULT = "__airflow_xcom_result_empty__" """ @@ -290,7 +289,7 @@ class PodManager(LoggingMixin): def __init__( self, kube_client: client.CoreV1Api, - callbacks: C = KubernetesPodOperatorCallback, + callbacks: type[KubernetesPodOperatorCallback] = KubernetesPodOperatorCallback, progress_callback: Callable[[str], None] | None = None, ): """ diff --git a/airflow/providers/google/cloud/triggers/kubernetes_engine.py b/airflow/providers/google/cloud/triggers/kubernetes_engine.py index 
c03a0bcb0531b..f91a42ef37ea9 100644 --- a/airflow/providers/google/cloud/triggers/kubernetes_engine.py +++ b/airflow/providers/google/cloud/triggers/kubernetes_engine.py @@ -20,7 +20,7 @@ import asyncio import warnings from functools import cached_property -from typing import TYPE_CHECKING, Any, AsyncIterator, Sequence, TypeVar +from typing import TYPE_CHECKING, Any, AsyncIterator, Sequence from google.cloud.container_v1.types import Operation @@ -39,8 +39,6 @@ if TYPE_CHECKING: from datetime import datetime -C = TypeVar("C", bound=KubernetesPodOperatorCallback) - class GKEStartPodTrigger(KubernetesPodTrigger): """ @@ -84,7 +82,7 @@ def __init__( startup_timeout: int = 120, on_finish_action: str = "delete_pod", should_delete_pod: bool | None = None, - callbacks: C = KubernetesPodOperatorCallback, + callbacks: type[KubernetesPodOperatorCallback] = KubernetesPodOperatorCallback, *args, **kwargs, ): From b1755afb973009a5cea7a855f592409d0cfaf057 Mon Sep 17 00:00:00 2001 From: hussein-awala Date: Thu, 11 Jan 2024 00:49:25 +0100 Subject: [PATCH 09/15] Add a doc paragraph for the new callbacks --- .../operators.rst | 93 +++++++++++++++++++ 1 file changed, 93 insertions(+) diff --git a/docs/apache-airflow-providers-cncf-kubernetes/operators.rst b/docs/apache-airflow-providers-cncf-kubernetes/operators.rst index 4c1dd9ac86d2f..fabb96dac44f8 100644 --- a/docs/apache-airflow-providers-cncf-kubernetes/operators.rst +++ b/docs/apache-airflow-providers-cncf-kubernetes/operators.rst @@ -193,6 +193,99 @@ included in the exception message if the task fails. Read more on termination-log `here `__. +KubernetesPodOperator callbacks +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The :class:`~airflow.providers.cncf.kubernetes.operators.pod.KubernetesPodOperator` supports different +callbacks that can be used to trigger actions during the lifecycle of the pod.
In order to use them, you need to +create a subclass of :class:`~airflow.providers.cncf.kubernetes.callbacks.KubernetesPodOperatorCallback` and override +the callback methods you want to use. Then you can pass your callback class to the operator using the ``callbacks`` +parameter. + +The following callbacks are supported: + +* on_sync_client_creation: called after creating the sync client +* on_async_client_creation: called after creating the async client +* on_pod_creation: called after creating the pod +* on_pod_starting: called after the pod starts +* on_pod_completion: called when the pod completes +* on_pod_cleanup: called after cleaning/deleting the pod +* on_operator_resuming: called when resuming the task from deferred state +* progress_callback: called on each line of the container logs + +Example: +~~~~~~~~ +.. code-block:: python + + import kubernetes.client as k8s + import kubernetes_asyncio.client as async_k8s + + from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator + from airflow.providers.cncf.kubernetes.callbacks import KubernetesPodOperatorCallback + + + class MyCallback(KubernetesPodOperatorCallback): + @staticmethod + def on_pod_creation(*, pod: k8s.V1Pod, client: k8s.CoreV1Api, mode: str, **kwargs) -> None: + # currently, the pod is always created when the task is running, so the mode is always "sync" + client.create_namespaced_service( + namespace=pod.metadata.namespace, + body=k8s.V1Service( + metadata=k8s.V1ObjectMeta( + name=pod.metadata.name, + labels=pod.metadata.labels, + owner_references=[ + k8s.V1OwnerReference( + api_version=pod.api_version, + kind=pod.kind, + name=pod.metadata.name, + uid=pod.metadata.uid, + controller=True, + block_owner_deletion=True, + ) + ], + ), + spec=k8s.V1ServiceSpec( + selector=pod.metadata.labels, + ports=[ + k8s.V1ServicePort( + name="http", + port=80, + target_port=80, + ) + ], + ), + ), + ) + + @staticmethod + def on_pod_starting( + *, pod: k8s.V1Pod, client: k8s.CoreV1Api |
async_k8s.CoreV1Api, mode: str, **kwargs + ) -> None: + # this callback can be called in sync or async mode, so we need to handle both cases and avoid blocking the event loop + import asyncio + + def _some_sync_function(): + ... + + async def _some_async_function(): + ... + + if mode == "sync": + _some_sync_function() + else: + asyncio.get_event_loop().run_until_complete(_some_async_function()) + + + k = KubernetesPodOperator( + task_id="test_callback", + image="alpine", + cmds=["/bin/sh"], + arguments=["-c", "echo hello world; echo Custom error > /dev/termination-log; exit 1;"], + name="test-callback", + callbacks=MyCallback, + ) + Reference ^^^^^^^^^ For further information, look at: From 1860881c98cdd49742cf4cb057462479e47fef7e Mon Sep 17 00:00:00 2001 From: hussein-awala Date: Thu, 11 Jan 2024 00:50:17 +0100 Subject: [PATCH 10/15] Add a check for cncf-kuberntes version in google provider --- .../providers/cncf/kubernetes/callbacks.py | 2 +- .../cloud/operators/kubernetes_engine.py | 21 ++++++++++++++++++- .../cloud/triggers/kubernetes_engine.py | 21 ++++++++++++++++--- .../cloud/operators/test_kubernetes_engine.py | 2 ++ .../cloud/triggers/test_kubernetes_engine.py | 13 +++++++++++- 5 files changed, 53 insertions(+), 6 deletions(-) diff --git a/airflow/providers/cncf/kubernetes/callbacks.py b/airflow/providers/cncf/kubernetes/callbacks.py index d32c1e29f4dae..0ed2fb0b2116f 100644 --- a/airflow/providers/cncf/kubernetes/callbacks.py +++ b/airflow/providers/cncf/kubernetes/callbacks.py @@ -95,7 +95,7 @@ def on_pod_cleanup(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs): def on_operator_resuming( *, pod: k8s.V1Pod, event: dict, client: client_type, mode: str, **kwargs ) -> None: - """Callback method called resuming the `KubernetesPodOperator` from deferred state. + """Callback method called when resuming the `KubernetesPodOperator` from deferred state. :param pod: the current state of the pod. :param event: the returned event from the Trigger. 
diff --git a/airflow/providers/google/cloud/operators/kubernetes_engine.py b/airflow/providers/google/cloud/operators/kubernetes_engine.py index 233a9d62a049a..974e62e64353d 100644 --- a/airflow/providers/google/cloud/operators/kubernetes_engine.py +++ b/airflow/providers/google/cloud/operators/kubernetes_engine.py @@ -41,7 +41,13 @@ KubernetesEnginePodLink, ) from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator -from airflow.providers.google.cloud.triggers.kubernetes_engine import GKEOperationTrigger, GKEStartPodTrigger +from airflow.providers.google.cloud.triggers.kubernetes_engine import ( + GKEOperationTrigger, + GKEStartPodTrigger, + callbacks_type, + default_callbacks, + is_generic_callbacks_supported, +) from airflow.utils.timezone import utcnow if TYPE_CHECKING: @@ -456,6 +462,7 @@ def __init__( regional: bool | None = None, on_finish_action: str | None = None, is_delete_operator_pod: bool | None = None, + callbacks: callbacks_type = default_callbacks, **kwargs, ) -> None: if is_delete_operator_pod is not None: @@ -490,6 +497,17 @@ def __init__( stacklevel=2, ) + if not is_generic_callbacks_supported and callbacks is not None: + warnings.warn( + "The `callbacks` parameter is not supported in this version of cncf.kubernetes." 
+ "Please upgrade to version 7.14.0 or newer.", + UserWarning, + stacklevel=2, + ) + self.callbacks: Any = callbacks + else: + kwargs["callbacks"] = callbacks + super().__init__(**kwargs) self.project_id = project_id self.location = location @@ -581,6 +599,7 @@ def invoke_defer_method(self): in_cluster=self.in_cluster, base_container_name=self.base_container_name, on_finish_action=self.on_finish_action, + callbacks=self.callbacks, ), method_name="execute_complete", kwargs={"cluster_url": self._cluster_url, "ssl_ca_cert": self._ssl_ca_cert}, diff --git a/airflow/providers/google/cloud/triggers/kubernetes_engine.py b/airflow/providers/google/cloud/triggers/kubernetes_engine.py index f91a42ef37ea9..cb79b64c0c394 100644 --- a/airflow/providers/google/cloud/triggers/kubernetes_engine.py +++ b/airflow/providers/google/cloud/triggers/kubernetes_engine.py @@ -20,12 +20,13 @@ import asyncio import warnings from functools import cached_property -from typing import TYPE_CHECKING, Any, AsyncIterator, Sequence +from typing import TYPE_CHECKING, Any, AsyncIterator, Sequence, Type +import packaging.version from google.cloud.container_v1.types import Operation from airflow.exceptions import AirflowProviderDeprecationWarning -from airflow.providers.cncf.kubernetes.callbacks import KubernetesPodOperatorCallback +from airflow.providers.cncf.kubernetes import __version__ as cnfc_kubernetes_version from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction try: @@ -39,6 +40,20 @@ if TYPE_CHECKING: from datetime import datetime +# TODO: Remove this check when we drop support for cncf-kubernetes < 7.14.0 +callbacks_type: Any +default_callbacks: Any +if packaging.version.parse(cnfc_kubernetes_version) >= packaging.version.parse("7.14.0"): + from airflow.providers.cncf.kubernetes.callbacks import KubernetesPodOperatorCallback + + is_generic_callbacks_supported = True + callbacks_type = Type[KubernetesPodOperatorCallback] + default_callbacks = 
KubernetesPodOperatorCallback +else: + is_generic_callbacks_supported = False + callbacks_type = Any + default_callbacks = None + class GKEStartPodTrigger(KubernetesPodTrigger): """ @@ -82,7 +97,7 @@ def __init__( startup_timeout: int = 120, on_finish_action: str = "delete_pod", should_delete_pod: bool | None = None, - callbacks: type[KubernetesPodOperatorCallback] = KubernetesPodOperatorCallback, + callbacks: callbacks_type = default_callbacks, *args, **kwargs, ): diff --git a/tests/providers/google/cloud/operators/test_kubernetes_engine.py b/tests/providers/google/cloud/operators/test_kubernetes_engine.py index ba574a1ea6800..40e3d480a2868 100644 --- a/tests/providers/google/cloud/operators/test_kubernetes_engine.py +++ b/tests/providers/google/cloud/operators/test_kubernetes_engine.py @@ -25,6 +25,7 @@ from airflow.exceptions import AirflowException, TaskDeferred from airflow.models import Connection +from airflow.providers.cncf.kubernetes.callbacks import KubernetesPodOperatorCallback from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction from airflow.providers.google.cloud.operators.kubernetes_engine import ( @@ -465,6 +466,7 @@ def test_async_create_pod_should_execute_successfully( with pytest.raises(TaskDeferred) as exc: self.gke_op._cluster_url = CLUSTER_URL self.gke_op._ssl_ca_cert = SSL_CA_CERT + self.gke_op.callbacks = KubernetesPodOperatorCallback self.gke_op.execute(context=mock.MagicMock()) fetch_cluster_info_mock.assert_called_once() assert isinstance(exc.value.trigger, GKEStartPodTrigger) diff --git a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py index a6b5e2cf2ee6d..8f061a6689905 100644 --- a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py +++ b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py @@ -29,7 +29,7 @@ from 
airflow.providers.cncf.kubernetes.callbacks import KubernetesPodOperatorCallback from airflow.providers.cncf.kubernetes.triggers.kubernetes_pod import ContainerState -from airflow.providers.google.cloud.triggers.kubernetes_engine import GKEOperationTrigger, GKEStartPodTrigger +from airflow.providers.google.cloud.triggers.kubernetes_engine import GKEOperationTrigger from airflow.triggers.base import TriggerEvent TRIGGER_GKE_PATH = "airflow.providers.google.cloud.triggers.kubernetes_engine.GKEStartPodTrigger" @@ -61,6 +61,17 @@ @pytest.fixture def trigger(): + # This is a workaround for `is_generic_callbacks_supported` check. + # TODO: Remove this workaround after releasing cncf.kubernetes=7.14.0 + import importlib + + from airflow.providers.google.cloud.triggers import kubernetes_engine + + mock.patch("packaging.version.parse", mock.MagicMock(return_value=1)).start() + importlib.reload(kubernetes_engine) + + from airflow.providers.google.cloud.triggers.kubernetes_engine import GKEStartPodTrigger + return GKEStartPodTrigger( pod_name=POD_NAME, pod_namespace=NAMESPACE, From d9fd4758a0eeb7ba424c07fb5cdc9966995d6113 Mon Sep 17 00:00:00 2001 From: hussein-awala Date: Sat, 13 Jan 2024 00:12:54 +0100 Subject: [PATCH 11/15] Switch to None default value and bump min cncf-k8s provider in google provider --- .../cncf/kubernetes/operators/pod.py | 57 +++++++++++-------- .../providers/cncf/kubernetes/triggers/pod.py | 36 ++++++------ .../cncf/kubernetes/utils/pod_manager.py | 16 +++--- .../google/cloud/hooks/kubernetes_engine.py | 25 +++++--- .../cloud/operators/kubernetes_engine.py | 21 +------ .../cloud/triggers/kubernetes_engine.py | 24 +++----- airflow/providers/google/provider.yaml | 2 +- .../cncf/kubernetes/triggers/test_pod.py | 3 +- 8 files changed, 89 insertions(+), 95 deletions(-) diff --git a/airflow/providers/cncf/kubernetes/operators/pod.py b/airflow/providers/cncf/kubernetes/operators/pod.py index 838ddb813735e..6518bf2143c2f 100644 --- 
a/airflow/providers/cncf/kubernetes/operators/pod.py +++ b/airflow/providers/cncf/kubernetes/operators/pod.py @@ -293,7 +293,7 @@ def __init__( is_delete_operator_pod: None | bool = None, termination_message_policy: str = "File", active_deadline_seconds: int | None = None, - callbacks: type[KubernetesPodOperatorCallback] = KubernetesPodOperatorCallback, + callbacks: type[KubernetesPodOperatorCallback] | None = None, progress_callback: Callable[[str], None] | None = None, **kwargs, ) -> None: @@ -480,7 +480,8 @@ def hook(self) -> PodOperatorHookProtocol: @cached_property def client(self) -> CoreV1Api: client = self.hook.core_v1_client - self.callbacks.on_sync_client_creation(client=client) + if self.callbacks: + self.callbacks.on_sync_client_creation(client=client) return client def find_pod(self, namespace: str, context: Context, *, exclude_checked: bool = True) -> k8s.V1Pod | None: @@ -558,13 +559,17 @@ def execute_sync(self, context: Context): # get remote pod for use in cleanup methods self.remote_pod = self.find_pod(self.pod.metadata.namespace, context=context) - self.callbacks.on_pod_creation(pod=self.remote_pod, client=self.client, mode=ExecutionMode.SYNC) + if self.callbacks: + self.callbacks.on_pod_creation( + pod=self.remote_pod, client=self.client, mode=ExecutionMode.SYNC + ) self.await_pod_start(pod=self.pod) - self.callbacks.on_pod_starting( - pod=self.find_pod(self.pod.metadata.namespace, context=context), - client=self.client, - mode=ExecutionMode.SYNC, - ) + if self.callbacks: + self.callbacks.on_pod_starting( + pod=self.find_pod(self.pod.metadata.namespace, context=context), + client=self.client, + mode=ExecutionMode.SYNC, + ) if self.get_logs: self.pod_manager.fetch_requested_container_logs( @@ -578,12 +583,12 @@ def execute_sync(self, context: Context): self.pod_manager.await_container_completion( pod=self.pod, container_name=self.base_container_name ) - - self.callbacks.on_pod_completion( - pod=self.find_pod(self.pod.metadata.namespace, 
context=context), - client=self.client, - mode=ExecutionMode.SYNC, - ) + if self.callbacks: + self.callbacks.on_pod_completion( + pod=self.find_pod(self.pod.metadata.namespace, context=context), + client=self.client, + mode=ExecutionMode.SYNC, + ) if self.do_xcom_push: self.pod_manager.await_xcom_sidecar_container_start(pod=self.pod) @@ -598,7 +603,8 @@ def execute_sync(self, context: Context): pod=pod_to_clean, remote_pod=self.remote_pod, ) - self.callbacks.on_pod_cleanup(pod=pod_to_clean, client=self.client, mode=ExecutionMode.SYNC) + if self.callbacks: + self.callbacks.on_pod_cleanup(pod=pod_to_clean, client=self.client, mode=ExecutionMode.SYNC) if self.do_xcom_push: return result @@ -608,11 +614,12 @@ def execute_async(self, context: Context): pod_request_obj=self.pod_request_obj, context=context, ) - self.callbacks.on_pod_creation( - pod=self.find_pod(self.pod.metadata.namespace, context=context), - client=self.client, - mode=ExecutionMode.SYNC, - ) + if self.callbacks: + self.callbacks.on_pod_creation( + pod=self.find_pod(self.pod.metadata.namespace, context=context), + client=self.client, + mode=ExecutionMode.SYNC, + ) ti = context["ti"] ti.xcom_push(key="pod_name", value=self.pod.metadata.name) ti.xcom_push(key="pod_namespace", value=self.pod.metadata.namespace) @@ -649,9 +656,10 @@ def execute_complete(self, context: Context, event: dict, **kwargs): event["name"], event["namespace"], ) - self.callbacks.on_operator_resuming( - pod=pod, event=event, client=self.client, mode=ExecutionMode.SYNC - ) + if self.callbacks: + self.callbacks.on_operator_resuming( + pod=pod, event=event, client=self.client, mode=ExecutionMode.SYNC + ) if event["status"] in ("error", "failed", "timeout"): # fetch some logs when pod is failed if self.get_logs: @@ -704,7 +712,8 @@ def post_complete_action(self, *, pod, remote_pod, **kwargs): pod=pod, remote_pod=remote_pod, ) - self.callbacks.on_pod_cleanup(pod=pod, client=self.client, mode=ExecutionMode.SYNC) + if self.callbacks: + 
self.callbacks.on_pod_cleanup(pod=pod, client=self.client, mode=ExecutionMode.SYNC) def cleanup(self, pod: k8s.V1Pod, remote_pod: k8s.V1Pod): istio_enabled = self.is_istio_enabled(remote_pod) diff --git a/airflow/providers/cncf/kubernetes/triggers/pod.py b/airflow/providers/cncf/kubernetes/triggers/pod.py index aee7c7b72c879..0e74706da2fd9 100644 --- a/airflow/providers/cncf/kubernetes/triggers/pod.py +++ b/airflow/providers/cncf/kubernetes/triggers/pod.py @@ -89,7 +89,7 @@ def __init__( startup_check_interval: int = 1, on_finish_action: str = "delete_pod", should_delete_pod: bool | None = None, - callbacks: type[KubernetesPodOperatorCallback] = KubernetesPodOperatorCallback, + callbacks: type[KubernetesPodOperatorCallback] | None = None, ): super().__init__() self.pod_name = pod_name @@ -162,15 +162,16 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] self.log.debug("Container %s status: %s", self.base_container_name, container_state) if container_state == ContainerState.TERMINATED: - if not _is_starting_callback_called: + if self.callbacks and not _is_starting_callback_called: self.callbacks.on_pod_starting( pod=pod, client=self._get_async_hook().core_v1_client, mode=ExecutionMode.ASYNC, ) - self.callbacks.on_pod_completion( - pod=pod, client=self._get_async_hook().core_v1_client, mode=ExecutionMode.ASYNC - ) + if self.callbacks: + self.callbacks.on_pod_completion( + pod=pod, client=self._get_async_hook().core_v1_client, mode=ExecutionMode.ASYNC + ) yield TriggerEvent( { "name": self.pod_name, @@ -203,7 +204,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] self.log.info("Sleeping for %s seconds.", self.startup_check_interval) await asyncio.sleep(self.startup_check_interval) else: - if not _is_starting_callback_called: + if self.callbacks and not _is_starting_callback_called: # if the trigger fails and re-run on a different triggerer, this callback could # be called again 
self.callbacks.on_pod_starting( @@ -215,15 +216,16 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] self.log.info("Sleeping for %s seconds.", self.poll_interval) await asyncio.sleep(self.poll_interval) else: - if not _is_starting_callback_called: + if self.callbacks and not _is_starting_callback_called: self.callbacks.on_pod_starting( pod=pod, client=self._get_async_hook().core_v1_client, mode=ExecutionMode.ASYNC, ) - self.callbacks.on_pod_completion( - pod=pod, client=self._get_async_hook().core_v1_client, mode=ExecutionMode.ASYNC - ) + if self.callbacks: + self.callbacks.on_pod_completion( + pod=pod, client=self._get_async_hook().core_v1_client, mode=ExecutionMode.ASYNC + ) yield TriggerEvent( { "name": self.pod_name, @@ -247,11 +249,12 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] name=self.pod_name, namespace=self.pod_namespace, ) - self.callbacks.on_pod_cleanup( - pod=await self.hook.get_pod(name=self.pod_name, namespace=self.pod_namespace), - client=self._get_async_hook().core_v1_client, - mode=ExecutionMode.ASYNC, - ) + if self.callbacks: + self.callbacks.on_pod_cleanup( + pod=await self.hook.get_pod(name=self.pod_name, namespace=self.pod_namespace), + client=self._get_async_hook().core_v1_client, + mode=ExecutionMode.ASYNC, + ) yield TriggerEvent( { "name": self.pod_name, @@ -280,7 +283,8 @@ def _get_async_hook(self) -> AsyncKubernetesHook: config_file=self.config_file, cluster_context=self.cluster_context, ) - self.callbacks.on_async_client_creation(client=_hook.core_v1_client) + if self.callbacks: + self.callbacks.on_async_client_creation(client=_hook.core_v1_client) return _hook @cached_property diff --git a/airflow/providers/cncf/kubernetes/utils/pod_manager.py b/airflow/providers/cncf/kubernetes/utils/pod_manager.py index d08677fad1b28..0e736daa6aa76 100644 --- a/airflow/providers/cncf/kubernetes/utils/pod_manager.py +++ b/airflow/providers/cncf/kubernetes/utils/pod_manager.py @@ -289,7 
+289,7 @@ class PodManager(LoggingMixin): def __init__( self, kube_client: client.CoreV1Api, - callbacks: type[KubernetesPodOperatorCallback] = KubernetesPodOperatorCallback, + callbacks: type[KubernetesPodOperatorCallback] | None = None, progress_callback: Callable[[str], None] | None = None, ): """ @@ -450,9 +450,10 @@ def consume_logs(*, since_time: DateTime | None = None) -> tuple[DateTime | None for line in progress_callback_lines: if self._progress_callback: self._progress_callback(line) - self._callbacks.progress_callback( - line=line, client=self._client, mode=ExecutionMode.SYNC - ) + if self._callbacks: + self._callbacks.progress_callback( + line=line, client=self._client, mode=ExecutionMode.SYNC + ) self.log.info("[%s] %s", container_name, message_to_log) last_captured_timestamp = message_timestamp message_to_log = message @@ -466,9 +467,10 @@ def consume_logs(*, since_time: DateTime | None = None) -> tuple[DateTime | None for line in progress_callback_lines: if self._progress_callback: self._progress_callback(line) - self._callbacks.progress_callback( - line=line, client=self._client, mode=ExecutionMode.SYNC - ) + if self._callbacks: + self._callbacks.progress_callback( + line=line, client=self._client, mode=ExecutionMode.SYNC + ) self.log.info("[%s] %s", container_name, message_to_log) last_captured_timestamp = message_timestamp except TimeoutError as e: diff --git a/airflow/providers/google/cloud/hooks/kubernetes_engine.py b/airflow/providers/google/cloud/hooks/kubernetes_engine.py index df0bd86025224..5e29b2f40bb0f 100644 --- a/airflow/providers/google/cloud/hooks/kubernetes_engine.py +++ b/airflow/providers/google/cloud/hooks/kubernetes_engine.py @@ -444,14 +444,23 @@ def __init__(self, cluster_url: str, ssl_ca_cert: str, **kwargs) -> None: super().__init__(cluster_url=cluster_url, ssl_ca_cert=ssl_ca_cert, **kwargs) @contextlib.asynccontextmanager - async def get_conn(self, token: Token) -> async_client.ApiClient: # type: ignore[override] - 
kube_client = None - try: - kube_client = await self._load_config(token) - yield kube_client - finally: - if kube_client is not None: - await kube_client.close() + async def get_conn(self, token: Token | None = None) -> async_client.ApiClient: # type: ignore[override] + async def _get_conn(_token: Token) -> async_client.ApiClient: + _kube_client = None + try: + _kube_client = await self._load_config(_token) + yield _kube_client + finally: + if _kube_client is not None: + await _kube_client.close() + + if token is None: + async with Token(scopes=self.scopes) as token: + async with _get_conn(token) as kube_client: + yield kube_client + else: + async with _get_conn(token) as kube_client: + yield kube_client async def _load_config(self, token: Token) -> async_client.ApiClient: configuration = self._get_config() diff --git a/airflow/providers/google/cloud/operators/kubernetes_engine.py b/airflow/providers/google/cloud/operators/kubernetes_engine.py index ebf06b6251831..9b12d8fdf76ba 100644 --- a/airflow/providers/google/cloud/operators/kubernetes_engine.py +++ b/airflow/providers/google/cloud/operators/kubernetes_engine.py @@ -35,13 +35,7 @@ KubernetesEnginePodLink, ) from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator -from airflow.providers.google.cloud.triggers.kubernetes_engine import ( - GKEOperationTrigger, - GKEStartPodTrigger, - callbacks_type, - default_callbacks, - is_generic_callbacks_supported, -) +from airflow.providers.google.cloud.triggers.kubernetes_engine import GKEOperationTrigger, GKEStartPodTrigger from airflow.utils.timezone import utcnow if TYPE_CHECKING: @@ -456,7 +450,6 @@ def __init__( regional: bool | None = None, on_finish_action: str | None = None, is_delete_operator_pod: bool | None = None, - callbacks: callbacks_type = default_callbacks, **kwargs, ) -> None: if is_delete_operator_pod is not None: @@ -490,18 +483,6 @@ def __init__( AirflowProviderDeprecationWarning, stacklevel=2, ) - - if not 
is_generic_callbacks_supported and callbacks is not None: - warnings.warn( - "The `callbacks` parameter is not supported in this version of cncf.kubernetes." - "Please upgrade to version 7.14.0 or newer.", - UserWarning, - stacklevel=2, - ) - self.callbacks: Any = callbacks - else: - kwargs["callbacks"] = callbacks - super().__init__(**kwargs) self.project_id = project_id self.location = location diff --git a/airflow/providers/google/cloud/triggers/kubernetes_engine.py b/airflow/providers/google/cloud/triggers/kubernetes_engine.py index ade59c5e35274..685463f429160 100644 --- a/airflow/providers/google/cloud/triggers/kubernetes_engine.py +++ b/airflow/providers/google/cloud/triggers/kubernetes_engine.py @@ -20,13 +20,12 @@ import asyncio import warnings from functools import cached_property -from typing import TYPE_CHECKING, Any, AsyncIterator, Sequence, Type +from typing import TYPE_CHECKING, Any, AsyncIterator, Sequence -import packaging.version from google.cloud.container_v1.types import Operation +from kubernetes_asyncio.client import CoreV1Api from airflow.exceptions import AirflowProviderDeprecationWarning -from airflow.providers.cncf.kubernetes import __version__ as cnfc_kubernetes_version from airflow.providers.cncf.kubernetes.triggers.pod import KubernetesPodTrigger from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction from airflow.providers.google.cloud.hooks.kubernetes_engine import GKEAsyncHook, GKEPodAsyncHook @@ -35,20 +34,8 @@ if TYPE_CHECKING: from datetime import datetime -# TODO: Remove this check when we drop support for cncf-kubernetes < 7.14.0 -callbacks_type: Any -default_callbacks: Any -if packaging.version.parse(cnfc_kubernetes_version) >= packaging.version.parse("7.14.0"): from airflow.providers.cncf.kubernetes.callbacks import KubernetesPodOperatorCallback - is_generic_callbacks_supported = True - callbacks_type = Type[KubernetesPodOperatorCallback] - default_callbacks = KubernetesPodOperatorCallback -else: - 
is_generic_callbacks_supported = False - callbacks_type = Any - default_callbacks = None - class GKEStartPodTrigger(KubernetesPodTrigger): """ @@ -92,7 +79,7 @@ def __init__( startup_timeout: int = 120, on_finish_action: str = "delete_pod", should_delete_pod: bool | None = None, - callbacks: callbacks_type = default_callbacks, + callbacks: type[KubernetesPodOperatorCallback] | None = None, *args, **kwargs, ): @@ -154,10 +141,13 @@ def serialize(self) -> tuple[str, dict[str, Any]]: @cached_property def hook(self) -> GKEPodAsyncHook: # type: ignore[override] - return GKEPodAsyncHook( + _hook = GKEPodAsyncHook( cluster_url=self._cluster_url, ssl_ca_cert=self._ssl_ca_cert, ) + if self.callbacks: + self.callbacks.on_async_client_creation(client=CoreV1Api(_hook.get_conn())) + return _hook class GKEOperationTrigger(BaseTrigger): diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml index eb8a728211c2e..1dcd2e21a34df 100644 --- a/airflow/providers/google/provider.yaml +++ b/airflow/providers/google/provider.yaml @@ -151,7 +151,7 @@ additional-extras: - apache-beam[gcp] - name: cncf.kubernetes dependencies: - - apache-airflow-providers-cncf-kubernetes>=7.2.0 + - apache-airflow-providers-cncf-kubernetes>=7.14.0 - name: leveldb dependencies: - plyvel diff --git a/tests/providers/cncf/kubernetes/triggers/test_pod.py b/tests/providers/cncf/kubernetes/triggers/test_pod.py index 6b070ef2e6147..9deac5b60ff65 100644 --- a/tests/providers/cncf/kubernetes/triggers/test_pod.py +++ b/tests/providers/cncf/kubernetes/triggers/test_pod.py @@ -26,7 +26,6 @@ import pytest from kubernetes.client import models as k8s -from airflow.providers.cncf.kubernetes.callbacks import KubernetesPodOperatorCallback from airflow.providers.cncf.kubernetes.triggers.pod import ContainerState, KubernetesPodTrigger from airflow.providers.cncf.kubernetes.utils.pod_manager import PodPhase from airflow.triggers.base import TriggerEvent @@ -95,7 +94,7 @@ def 
test_serialize(self, trigger): "trigger_start_time": TRIGGER_START_TIME, "on_finish_action": ON_FINISH_ACTION, "should_delete_pod": ON_FINISH_ACTION == "delete_pod", - "callbacks": KubernetesPodOperatorCallback, + "callbacks": None, } @pytest.mark.asyncio From 39ece30d020552ffb8c2ce06d622d2790c6015c9 Mon Sep 17 00:00:00 2001 From: hussein-awala Date: Sat, 13 Jan 2024 01:02:11 +0100 Subject: [PATCH 12/15] Fix tests --- dev/breeze/tests/test_packages.py | 2 +- .../cloud/triggers/test_kubernetes_engine.py | 22 +++---------------- 2 files changed, 4 insertions(+), 20 deletions(-) diff --git a/dev/breeze/tests/test_packages.py b/dev/breeze/tests/test_packages.py index 017a71e9d63e6..58032aa2a6c80 100644 --- a/dev/breeze/tests/test_packages.py +++ b/dev/breeze/tests/test_packages.py @@ -174,7 +174,7 @@ def test_get_package_extras(): "amazon": ["apache-airflow-providers-amazon>=2.6.0"], "apache.beam": ["apache-airflow-providers-apache-beam", "apache-beam[gcp]"], "apache.cassandra": ["apache-airflow-providers-apache-cassandra"], - "cncf.kubernetes": ["apache-airflow-providers-cncf-kubernetes>=7.2.0"], + "cncf.kubernetes": ["apache-airflow-providers-cncf-kubernetes>=7.14.0"], "common.sql": ["apache-airflow-providers-common-sql"], "facebook": ["apache-airflow-providers-facebook>=2.2.0"], "leveldb": ["plyvel"], diff --git a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py index 52c08a1d1a182..637380cf07ead 100644 --- a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py +++ b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py @@ -27,9 +27,8 @@ from google.cloud.container_v1.types import Operation from kubernetes.client import models as k8s -from airflow.providers.cncf.kubernetes.callbacks import KubernetesPodOperatorCallback from airflow.providers.cncf.kubernetes.triggers.kubernetes_pod import ContainerState -from airflow.providers.google.cloud.triggers.kubernetes_engine 
import GKEOperationTrigger +from airflow.providers.google.cloud.triggers.kubernetes_engine import GKEOperationTrigger, GKEStartPodTrigger from airflow.triggers.base import TriggerEvent TRIGGER_GKE_PATH = "airflow.providers.google.cloud.triggers.kubernetes_engine.GKEStartPodTrigger" @@ -61,17 +60,6 @@ @pytest.fixture def trigger(): - # This is a workaround for `is_generic_callbacks_supported` check. - # TODO: Remove this workaround after releasing cncf.kubernetes=7.14.0 - import importlib - - from airflow.providers.google.cloud.triggers import kubernetes_engine - - mock.patch("packaging.version.parse", mock.MagicMock(return_value=1)).start() - importlib.reload(kubernetes_engine) - - from airflow.providers.google.cloud.triggers.kubernetes_engine import GKEStartPodTrigger - return GKEStartPodTrigger( pod_name=POD_NAME, pod_namespace=NAMESPACE, @@ -113,7 +101,7 @@ def test_serialize_should_execute_successfully(self, trigger): "base_container_name": BASE_CONTAINER_NAME, "on_finish_action": ON_FINISH_ACTION, "should_delete_pod": SHOULD_DELETE_POD, - "callbacks": KubernetesPodOperatorCallback, + "callbacks": None, } @pytest.mark.asyncio @@ -246,11 +234,7 @@ async def test_logging_in_trigger_when_cancelled_should_execute_successfully( """ Test that GKEStartPodTrigger fires the correct event in case if the task was cancelled. 
""" - - async def async_mock(): - return mock.MagicMock() - - mock_hook.get_pod.side_effect = [CancelledError(), async_mock()] + mock_hook.get_pod.side_effect = CancelledError() mock_hook.read_logs.return_value = self._mock_pod_result(mock.MagicMock()) mock_hook.delete_pod.return_value = self._mock_pod_result(mock.MagicMock()) From 1d133d7afe2ed6312224527071aa7cdc69a7ce72 Mon Sep 17 00:00:00 2001 From: hussein-awala Date: Sat, 13 Jan 2024 18:15:31 +0100 Subject: [PATCH 13/15] Reduce check intervals to avoid killing the asyncio task --- tests/providers/cncf/kubernetes/triggers/test_pod.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/providers/cncf/kubernetes/triggers/test_pod.py b/tests/providers/cncf/kubernetes/triggers/test_pod.py index 9deac5b60ff65..1d765d9bb453f 100644 --- a/tests/providers/cncf/kubernetes/triggers/test_pod.py +++ b/tests/providers/cncf/kubernetes/triggers/test_pod.py @@ -401,7 +401,7 @@ async def test_callbacks(self, mock_hook, mock_method): pod_namespace=NAMESPACE, base_container_name=BASE_CONTAINER_NAME, kubernetes_conn_id=CONN_ID, - poll_interval=POLL_INTERVAL, + poll_interval=0, cluster_context=CLUSTER_CONTEXT, config_file=CONFIG_FILE, in_cluster=IN_CLUSTER, @@ -410,6 +410,7 @@ async def test_callbacks(self, mock_hook, mock_method): trigger_start_time=TRIGGER_START_TIME, on_finish_action=ON_FINISH_ACTION, callbacks=MockKubernetesPodOperatorCallback, + startup_check_interval=0, ) await k.run().asend(None) From 79354e8a09997d44bad9e08c087bb7e2434f82c3 Mon Sep 17 00:00:00 2001 From: hussein-awala Date: Sat, 20 Jan 2024 01:50:40 +0100 Subject: [PATCH 14/15] Revert async callbacks --- .../providers/cncf/kubernetes/callbacks.py | 14 ++--- .../providers/cncf/kubernetes/triggers/pod.py | 45 +------------- .../google/cloud/hooks/kubernetes_engine.py | 25 +++----- .../cloud/operators/kubernetes_engine.py | 2 +- .../cloud/triggers/kubernetes_engine.py | 11 +--- airflow/providers/google/provider.yaml | 2 +- 
.../operators.rst | 22 +------ .../cncf/kubernetes/triggers/test_pod.py | 62 +------------------ .../cloud/operators/test_kubernetes_engine.py | 12 +--- .../cloud/triggers/test_kubernetes_engine.py | 1 - 10 files changed, 22 insertions(+), 174 deletions(-) diff --git a/airflow/providers/cncf/kubernetes/callbacks.py b/airflow/providers/cncf/kubernetes/callbacks.py index 0ed2fb0b2116f..4baef440deec6 100644 --- a/airflow/providers/cncf/kubernetes/callbacks.py +++ b/airflow/providers/cncf/kubernetes/callbacks.py @@ -33,7 +33,11 @@ class ExecutionMode(str, Enum): class KubernetesPodOperatorCallback: - """`KubernetesPodOperator` callbacks methods.""" + """`KubernetesPodOperator` callbacks methods. + + Currently, the callbacks methods are not called in the async mode, this support will be added + in the future. + """ @staticmethod def on_sync_client_creation(*, client: k8s.CoreV1Api, **kwargs) -> None: @@ -43,14 +47,6 @@ def on_sync_client_creation(*, client: k8s.CoreV1Api, **kwargs) -> None: """ pass - @staticmethod - def on_async_client_creation(*, client: async_k8s.CoreV1Api, **kwargs) -> None: - """Callback method called after creating the async client. - - :param client: the created `kubernetes_asyncio.client.CoreV1Api` client. - """ - pass - @staticmethod def on_pod_creation(*, pod: k8s.V1Pod, client: client_type, mode: str, **kwargs) -> None: """Callback method called after creating the pod. 
diff --git a/airflow/providers/cncf/kubernetes/triggers/pod.py b/airflow/providers/cncf/kubernetes/triggers/pod.py index 0e74706da2fd9..3dd9eb173ca57 100644 --- a/airflow/providers/cncf/kubernetes/triggers/pod.py +++ b/airflow/providers/cncf/kubernetes/triggers/pod.py @@ -26,7 +26,6 @@ from typing import TYPE_CHECKING, Any, AsyncIterator from airflow.exceptions import AirflowProviderDeprecationWarning -from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode, KubernetesPodOperatorCallback from airflow.providers.cncf.kubernetes.hooks.kubernetes import AsyncKubernetesHook from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction, PodPhase from airflow.triggers.base import BaseTrigger, TriggerEvent @@ -89,7 +88,6 @@ def __init__( startup_check_interval: int = 1, on_finish_action: str = "delete_pod", should_delete_pod: bool | None = None, - callbacks: type[KubernetesPodOperatorCallback] | None = None, ): super().__init__() self.pod_name = pod_name @@ -104,7 +102,6 @@ def __init__( self.get_logs = get_logs self.startup_timeout = startup_timeout self.startup_check_interval = startup_check_interval - self.callbacks = callbacks if should_delete_pod is not None: warnings.warn( @@ -140,14 +137,12 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "trigger_start_time": self.trigger_start_time, "should_delete_pod": self.should_delete_pod, "on_finish_action": self.on_finish_action.value, - "callbacks": self.callbacks, }, ) async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] """Get current pod status and yield a TriggerEvent.""" self.log.info("Checking pod %r in namespace %r.", self.pod_name, self.pod_namespace) - _is_starting_callback_called = False try: while True: pod = await self.hook.get_pod( @@ -162,16 +157,6 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] self.log.debug("Container %s status: %s", self.base_container_name, container_state) if container_state == 
ContainerState.TERMINATED: - if self.callbacks and not _is_starting_callback_called: - self.callbacks.on_pod_starting( - pod=pod, - client=self._get_async_hook().core_v1_client, - mode=ExecutionMode.ASYNC, - ) - if self.callbacks: - self.callbacks.on_pod_completion( - pod=pod, client=self._get_async_hook().core_v1_client, mode=ExecutionMode.ASYNC - ) yield TriggerEvent( { "name": self.pod_name, @@ -204,28 +189,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] self.log.info("Sleeping for %s seconds.", self.startup_check_interval) await asyncio.sleep(self.startup_check_interval) else: - if self.callbacks and not _is_starting_callback_called: - # if the trigger fails and re-run on a different triggerer, this callback could - # be called again - self.callbacks.on_pod_starting( - pod=pod, - client=self._get_async_hook().core_v1_client, - mode=ExecutionMode.ASYNC, - ) - _is_starting_callback_called = True self.log.info("Sleeping for %s seconds.", self.poll_interval) await asyncio.sleep(self.poll_interval) else: - if self.callbacks and not _is_starting_callback_called: - self.callbacks.on_pod_starting( - pod=pod, - client=self._get_async_hook().core_v1_client, - mode=ExecutionMode.ASYNC, - ) - if self.callbacks: - self.callbacks.on_pod_completion( - pod=pod, client=self._get_async_hook().core_v1_client, mode=ExecutionMode.ASYNC - ) yield TriggerEvent( { "name": self.pod_name, @@ -249,12 +215,6 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] name=self.pod_name, namespace=self.pod_namespace, ) - if self.callbacks: - self.callbacks.on_pod_cleanup( - pod=await self.hook.get_pod(name=self.pod_name, namespace=self.pod_namespace), - client=self._get_async_hook().core_v1_client, - mode=ExecutionMode.ASYNC, - ) yield TriggerEvent( { "name": self.pod_name, @@ -277,15 +237,12 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override] def _get_async_hook(self) -> AsyncKubernetesHook: # TODO: Remove 
this method when the min version of kubernetes provider is 7.12.0 in Google provider. - _hook = AsyncKubernetesHook( + return AsyncKubernetesHook( conn_id=self.kubernetes_conn_id, in_cluster=self.in_cluster, config_file=self.config_file, cluster_context=self.cluster_context, ) - if self.callbacks: - self.callbacks.on_async_client_creation(client=_hook.core_v1_client) - return _hook @cached_property def hook(self) -> AsyncKubernetesHook: diff --git a/airflow/providers/google/cloud/hooks/kubernetes_engine.py b/airflow/providers/google/cloud/hooks/kubernetes_engine.py index 55f0a1a3ee4dd..0e62b990da921 100644 --- a/airflow/providers/google/cloud/hooks/kubernetes_engine.py +++ b/airflow/providers/google/cloud/hooks/kubernetes_engine.py @@ -446,23 +446,14 @@ def __init__(self, cluster_url: str, ssl_ca_cert: str, **kwargs) -> None: super().__init__(cluster_url=cluster_url, ssl_ca_cert=ssl_ca_cert, **kwargs) @contextlib.asynccontextmanager - async def get_conn(self, token: Token | None = None) -> async_client.ApiClient: # type: ignore[override] - async def _get_conn(_token: Token) -> async_client.ApiClient: - _kube_client = None - try: - _kube_client = await self._load_config(_token) - yield _kube_client - finally: - if _kube_client is not None: - await _kube_client.close() - - if token is None: - async with Token(scopes=self.scopes) as token: - async with _get_conn(token) as kube_client: - yield kube_client - else: - async with _get_conn(token) as kube_client: - yield kube_client + async def get_conn(self, token: Token) -> async_client.ApiClient: # type: ignore[override] + kube_client = None + try: + kube_client = await self._load_config(token) + yield kube_client + finally: + if kube_client is not None: + await kube_client.close() async def _load_config(self, token: Token) -> async_client.ApiClient: configuration = self._get_config() diff --git a/airflow/providers/google/cloud/operators/kubernetes_engine.py b/airflow/providers/google/cloud/operators/kubernetes_engine.py 
index f4b08eb9e828b..e5ca3c271b938 100644 --- a/airflow/providers/google/cloud/operators/kubernetes_engine.py +++ b/airflow/providers/google/cloud/operators/kubernetes_engine.py @@ -485,6 +485,7 @@ def __init__( AirflowProviderDeprecationWarning, stacklevel=2, ) + super().__init__(**kwargs) self.project_id = project_id self.location = location @@ -576,7 +577,6 @@ def invoke_defer_method(self): in_cluster=self.in_cluster, base_container_name=self.base_container_name, on_finish_action=self.on_finish_action, - callbacks=self.callbacks, ), method_name="execute_complete", kwargs={"cluster_url": self._cluster_url, "ssl_ca_cert": self._ssl_ca_cert}, diff --git a/airflow/providers/google/cloud/triggers/kubernetes_engine.py b/airflow/providers/google/cloud/triggers/kubernetes_engine.py index c038d7f3c8b00..da068dcfc3d02 100644 --- a/airflow/providers/google/cloud/triggers/kubernetes_engine.py +++ b/airflow/providers/google/cloud/triggers/kubernetes_engine.py @@ -23,7 +23,6 @@ from typing import TYPE_CHECKING, Any, AsyncIterator, Sequence from google.cloud.container_v1.types import Operation -from kubernetes_asyncio.client import CoreV1Api from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.providers.cncf.kubernetes.triggers.pod import KubernetesPodTrigger @@ -34,8 +33,6 @@ if TYPE_CHECKING: from datetime import datetime - from airflow.providers.cncf.kubernetes.callbacks import KubernetesPodOperatorCallback - class GKEStartPodTrigger(KubernetesPodTrigger): """ @@ -79,7 +76,6 @@ def __init__( startup_timeout: int = 120, on_finish_action: str = "delete_pod", should_delete_pod: bool | None = None, - callbacks: type[KubernetesPodOperatorCallback] | None = None, *args, **kwargs, ): @@ -100,7 +96,6 @@ def __init__( self.in_cluster = in_cluster self.get_logs = get_logs self.startup_timeout = startup_timeout - self.callbacks = callbacks if should_delete_pod is not None: warnings.warn( @@ -136,19 +131,15 @@ def serialize(self) -> tuple[str, dict[str, Any]]: 
"base_container_name": self.base_container_name, "should_delete_pod": self.should_delete_pod, "on_finish_action": self.on_finish_action.value, - "callbacks": self.callbacks, }, ) @cached_property def hook(self) -> GKEPodAsyncHook: # type: ignore[override] - _hook = GKEPodAsyncHook( + return GKEPodAsyncHook( cluster_url=self._cluster_url, ssl_ca_cert=self._ssl_ca_cert, ) - if self.callbacks: - self.callbacks.on_async_client_creation(client=CoreV1Api(_hook.get_conn())) - return _hook class GKEOperationTrigger(BaseTrigger): diff --git a/airflow/providers/google/provider.yaml b/airflow/providers/google/provider.yaml index 1dcd2e21a34df..eb8a728211c2e 100644 --- a/airflow/providers/google/provider.yaml +++ b/airflow/providers/google/provider.yaml @@ -151,7 +151,7 @@ additional-extras: - apache-beam[gcp] - name: cncf.kubernetes dependencies: - - apache-airflow-providers-cncf-kubernetes>=7.14.0 + - apache-airflow-providers-cncf-kubernetes>=7.2.0 - name: leveldb dependencies: - plyvel diff --git a/docs/apache-airflow-providers-cncf-kubernetes/operators.rst b/docs/apache-airflow-providers-cncf-kubernetes/operators.rst index 4ed03bba5650c..f24fd61602efa 100644 --- a/docs/apache-airflow-providers-cncf-kubernetes/operators.rst +++ b/docs/apache-airflow-providers-cncf-kubernetes/operators.rst @@ -207,7 +207,6 @@ parameter. The following callbacks are supported: * on_sync_client_creation: called after creating the sync client -* on_async_client_creation: called after creating the async client * on_pod_creation: called after creating the pod * on_pod_starting: called after the pod starts * on_pod_completion: called when the pod completes @@ -215,6 +214,8 @@ The following callbacks are supported: * on_operator_resuming: when resuming the task from deferred state * progress_callback: called on each line of containers logs +Currently, the callback methods are not called in async mode; this support will be added in the future. + Example: ~~~~~~~~ ..
code-block:: python @@ -229,7 +230,6 @@ Example: class MyCallback(KubernetesPodOperatorCallback): @staticmethod def on_pod_creation(*, pod: k8s.V1Pod, client: k8s.CoreV1Api, mode: str, **kwargs) -> None: - # currently, the pod is always created when the task is running, so the mode is always "sync" client.create_namespaced_service( namespace=pod.metadata.namespace, body=k8s.V1Service( @@ -260,24 +260,6 @@ Example: ), ) - @staticmethod - def on_pod_starting( - *, pod: k8s.V1Pod, client: k8s.CoreV1Api | async_k8s.CoreV1Api, mode: str, **kwargs - ) -> None: - # this callback can be called in sync or async mode, so we need to handle both cases and avoid blocking the event loop - import asyncio - - def _some_sync_function(): - ... - - async def _some_async_function(): - ... - - if mode == "sync": - _some_sync_function() - else: - asyncio.get_event_loop().run_until_complete(_some_async_function()) - k = KubernetesPodOperator( task_id="test_callback", diff --git a/tests/providers/cncf/kubernetes/triggers/test_pod.py b/tests/providers/cncf/kubernetes/triggers/test_pod.py index 1d765d9bb453f..9c016ea8cfb9a 100644 --- a/tests/providers/cncf/kubernetes/triggers/test_pod.py +++ b/tests/providers/cncf/kubernetes/triggers/test_pod.py @@ -47,10 +47,6 @@ ON_FINISH_ACTION = "delete_pod" -async def async_mock(): - return mock.MagicMock() - - @pytest.fixture def trigger(): return KubernetesPodTrigger( @@ -94,7 +90,6 @@ def test_serialize(self, trigger): "trigger_start_time": TRIGGER_START_TIME, "on_finish_action": ON_FINISH_ACTION, "should_delete_pod": ON_FINISH_ACTION == "delete_pod", - "callbacks": None, } @pytest.mark.asyncio @@ -224,7 +219,7 @@ async def test_logging_in_trigger_when_cancelled_should_execute_successfully_and Test that KubernetesPodTrigger fires the correct event in case if the task was cancelled. 
""" - mock_hook.get_pod.side_effect = [CancelledError(), async_mock()] + mock_hook.get_pod.side_effect = CancelledError() mock_hook.read_logs.return_value = self._mock_pod_result(mock.MagicMock()) mock_hook.delete_pod.return_value = self._mock_pod_result(mock.MagicMock()) @@ -270,7 +265,7 @@ async def test_logging_in_trigger_when_cancelled_should_execute_successfully_wit Test that KubernetesPodTrigger fires the correct event if the task was cancelled. """ - mock_hook.get_pod.side_effect = [CancelledError(), async_mock()] + mock_hook.get_pod.side_effect = CancelledError() mock_hook.read_logs.return_value = self._mock_pod_result(mock.MagicMock()) mock_hook.delete_pod.return_value = self._mock_pod_result(mock.MagicMock()) @@ -376,56 +371,3 @@ async def test_run_loop_return_timeout_event( ) == actual ) - - @pytest.mark.asyncio - @mock.patch(f"{TRIGGER_PATH}.define_container_state") - @mock.patch(f"{TRIGGER_PATH}._get_async_hook") - async def test_callbacks(self, mock_hook, mock_method): - from airflow.providers.cncf.kubernetes.callbacks import ExecutionMode - - from ..test_callbacks import MockKubernetesPodOperatorCallback, MockWrapper - - MockWrapper.reset() - mock_callbacks = MockWrapper.mock_callbacks - - pods_mock = [ - self._mock_pod_result(mock.MagicMock(status=mock.MagicMock(phase=PodPhase.PENDING))), - self._mock_pod_result(mock.MagicMock(status=mock.MagicMock(phase=PodPhase.RUNNING))), - self._mock_pod_result(mock.MagicMock(status=mock.MagicMock(phase=PodPhase.SUCCEEDED))), - ] - mock_hook.return_value.get_pod.side_effect = pods_mock - mock_method.side_effect = [ContainerState.WAITING, ContainerState.RUNNING, ContainerState.TERMINATED] - - k = KubernetesPodTrigger( - pod_name=POD_NAME, - pod_namespace=NAMESPACE, - base_container_name=BASE_CONTAINER_NAME, - kubernetes_conn_id=CONN_ID, - poll_interval=0, - cluster_context=CLUSTER_CONTEXT, - config_file=CONFIG_FILE, - in_cluster=IN_CLUSTER, - get_logs=GET_LOGS, - startup_timeout=STARTUP_TIMEOUT_SECS, - 
trigger_start_time=TRIGGER_START_TIME, - on_finish_action=ON_FINISH_ACTION, - callbacks=MockKubernetesPodOperatorCallback, - startup_check_interval=0, - ) - await k.run().asend(None) - - # check on_pod_starting callback - mock_callbacks.on_pod_starting.assert_called_once() - assert mock_callbacks.on_pod_starting.call_args.kwargs == { - "client": k._get_async_hook().core_v1_client, - "mode": ExecutionMode.ASYNC, - "pod": pods_mock[1].result(), - } - - # check on_pod_completion callback - mock_callbacks.on_pod_completion.assert_called_once() - assert mock_callbacks.on_pod_completion.call_args.kwargs == { - "client": k._get_async_hook().core_v1_client, - "mode": ExecutionMode.ASYNC, - "pod": pods_mock[2].result(), - } diff --git a/tests/providers/google/cloud/operators/test_kubernetes_engine.py b/tests/providers/google/cloud/operators/test_kubernetes_engine.py index f588dad57c403..7805804485682 100644 --- a/tests/providers/google/cloud/operators/test_kubernetes_engine.py +++ b/tests/providers/google/cloud/operators/test_kubernetes_engine.py @@ -25,7 +25,6 @@ from airflow.exceptions import AirflowException, TaskDeferred from airflow.models import Connection -from airflow.providers.cncf.kubernetes.callbacks import KubernetesPodOperatorCallback from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator from airflow.providers.cncf.kubernetes.utils.pod_manager import OnFinishAction from airflow.providers.google.cloud.operators.kubernetes_engine import ( @@ -441,8 +440,6 @@ def setup_method(self): self.gke_op._ssl_ca_cert = SSL_CA_CERT @mock.patch.dict(os.environ, {}) - @mock.patch(KUB_OP_PATH.format("client")) - @mock.patch(KUB_OP_PATH.format("find_pod")) @mock.patch(KUB_OP_PATH.format("build_pod_request_obj")) @mock.patch(KUB_OP_PATH.format("get_or_create_pod")) @mock.patch( @@ -451,13 +448,7 @@ def setup_method(self): ) @mock.patch(f"{GKE_OP_PATH}.fetch_cluster_info") def test_async_create_pod_should_execute_successfully( - self, - 
fetch_cluster_info_mock, - get_con_mock, - mocked_pod, - mocked_pod_obj, - mocked_found_pod, - mocked_client, + self, fetch_cluster_info_mock, get_con_mock, mocked_pod, mocked_pod_obj ): """ Asserts that a task is deferred and the GKEStartPodTrigger will be fired @@ -466,7 +457,6 @@ def test_async_create_pod_should_execute_successfully( with pytest.raises(TaskDeferred) as exc: self.gke_op._cluster_url = CLUSTER_URL self.gke_op._ssl_ca_cert = SSL_CA_CERT - self.gke_op.callbacks = KubernetesPodOperatorCallback self.gke_op.execute(context=mock.MagicMock()) fetch_cluster_info_mock.assert_called_once() assert isinstance(exc.value.trigger, GKEStartPodTrigger) diff --git a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py index 637380cf07ead..65bc45c41550c 100644 --- a/tests/providers/google/cloud/triggers/test_kubernetes_engine.py +++ b/tests/providers/google/cloud/triggers/test_kubernetes_engine.py @@ -101,7 +101,6 @@ def test_serialize_should_execute_successfully(self, trigger): "base_container_name": BASE_CONTAINER_NAME, "on_finish_action": ON_FINISH_ACTION, "should_delete_pod": SHOULD_DELETE_POD, - "callbacks": None, } @pytest.mark.asyncio From f56a925407533fabc8d576ff743a3e404c1391b6 Mon Sep 17 00:00:00 2001 From: hussein-awala Date: Sat, 20 Jan 2024 01:56:57 +0100 Subject: [PATCH 15/15] fix breeze tests --- dev/breeze/tests/test_packages.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev/breeze/tests/test_packages.py b/dev/breeze/tests/test_packages.py index 58032aa2a6c80..017a71e9d63e6 100644 --- a/dev/breeze/tests/test_packages.py +++ b/dev/breeze/tests/test_packages.py @@ -174,7 +174,7 @@ def test_get_package_extras(): "amazon": ["apache-airflow-providers-amazon>=2.6.0"], "apache.beam": ["apache-airflow-providers-apache-beam", "apache-beam[gcp]"], "apache.cassandra": ["apache-airflow-providers-apache-cassandra"], - "cncf.kubernetes": 
["apache-airflow-providers-cncf-kubernetes>=7.14.0"], + "cncf.kubernetes": ["apache-airflow-providers-cncf-kubernetes>=7.2.0"], "common.sql": ["apache-airflow-providers-common-sql"], "facebook": ["apache-airflow-providers-facebook>=2.2.0"], "leveldb": ["plyvel"],